Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
35e49f2d
Unverified
Commit
35e49f2d
authored
Aug 12, 2022
by
zjing14
Committed by
GitHub
Aug 12, 2022
Browse files
add g; fixed strides (#355)
parent
de60d290
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
279 additions
and
250 deletions
+279
-250
example/25_gemm_bias_e_permute/CMakeLists.txt
example/25_gemm_bias_e_permute/CMakeLists.txt
+2
-2
example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp
..._bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp
+131
-111
example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp
..._bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp
+141
-128
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
...ce/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
+5
-9
No files found.
example/25_gemm_bias_e_permute/CMakeLists.txt
View file @
35e49f2d
add_example_executable
(
example_gemm_bias_e_permute_m3n2_xdl_fp16 gemm_bias_e_permute_m3n2_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_bias_e_permute_
g1
m3n2
k1
_xdl_fp16 gemm_bias_e_permute_
g1
m3n2
k1
_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_bias_e_permute_m2n3_xdl_fp16 gemm_bias_e_permute_m2n3_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_bias_e_permute_
g1
m2n3
k1
_xdl_fp16 gemm_bias_e_permute_
g1
m2n3
k1
_xdl_fp16.cpp
)
example/25_gemm_bias_e_permute/gemm_bias_e_permute_m2n3_xdl_fp16.cpp
→
example/25_gemm_bias_e_permute/gemm_bias_e_permute_
g1
m2n3
k1
_xdl_fp16.cpp
View file @
35e49f2d
This diff is collapsed.
Click to expand it.
example/25_gemm_bias_e_permute/gemm_bias_e_permute_m3n2_xdl_fp16.cpp
→
example/25_gemm_bias_e_permute/gemm_bias_e_permute_
g1
m3n2
k1
_xdl_fp16.cpp
View file @
35e49f2d
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
View file @
35e49f2d
...
@@ -500,11 +500,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
...
@@ -500,11 +500,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
std
::
array
<
long_index_t
,
NumDTensor
>
ds_offset
;
std
::
array
<
long_index_t
,
NumDTensor
>
ds_offset
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
NumDimG
>
0
)
ds_offset
[
i
]
=
ds_offset
[
i
]
=
ds_grid_desc_g_m_n_
[
i
].
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
ds_grid_desc_g_m_n_
[
i
].
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
else
ds_offset
[
i
]
=
0
;
});
});
return
ds_offset
;
return
ds_offset
;
...
@@ -512,10 +509,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
...
@@ -512,10 +509,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
{
if
constexpr
(
NumDimG
>
0
)
return
e_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
return
e_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
else
return
0
;
}
}
private:
private:
...
@@ -634,6 +628,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
...
@@ -634,6 +628,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
compute_ptr_offset_of_batch_
{
compute_ptr_offset_of_batch_
{
a_batch_stride_
,
b_batch_stride_
,
ds_grid_desc_g_m_n_
,
e_grid_desc_g_m_n_
}
a_batch_stride_
,
b_batch_stride_
,
ds_grid_desc_g_m_n_
,
e_grid_desc_g_m_n_
}
{
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
,
""
);
// populate pointer, batch stride, desc for Ds
// populate pointer, batch stride, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment