Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
01876afa
Unverified
Commit
01876afa
authored
Sep 21, 2022
by
zjing14
Committed by
GitHub
Sep 21, 2022
Browse files
fixed G offset calc for long_index (#428)
parent
567f70f5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
5 deletions
+6
-5
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
...ce/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
+6
-5
No files found.
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
View file @
01876afa
...
...
@@ -506,12 +506,12 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
batch_stride_A_
)
;
return
static_cast
<
long_index_t
>
(
g_idx
)
*
batch_stride_A_
;
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
batch_stride_B_
)
;
return
static_cast
<
long_index_t
>
(
g_idx
)
*
batch_stride_B_
;
}
__host__
__device__
constexpr
auto
GetDsPtrOffset
(
index_t
g_idx
)
const
...
...
@@ -519,8 +519,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
std
::
array
<
long_index_t
,
NumDTensor
>
ds_offset
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
ds_offset
[
i
]
=
ds_grid_desc_g_m_n_
[
i
].
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
ds_offset
[
i
]
=
static_cast
<
long_index_t
>
(
g_idx
)
*
ds_grid_desc_g_m_n_
[
i
].
CalculateOffset
(
make_multi_index
(
1
,
0
,
0
));
});
return
ds_offset
;
...
...
@@ -528,7 +528,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
return
e_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
return
static_cast
<
long_index_t
>
(
g_idx
)
*
e_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
1
,
0
,
0
));
}
private:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment