gaoqiong / composable_kernel · Commits

Commit bdd0f64e
Authored Mar 06, 2023 by aska-0096
Parent a045e0be

Fix a bug

Showing 3 changed files with 43 additions and 45 deletions (+43 -45)
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc   +1  -1
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp                +1  -18
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp                  +41 -26
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc

@@ -26,7 +26,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
    {
        split_k = 1;
    }

    const auto in_g_n_c_wis_desc =
        ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
            InputLayout<NDimSpatial>>(conv_param);
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

@@ -62,20 +62,6 @@ struct BlockwiseGemmWMMA
     static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
     static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);

-    static constexpr auto A_temp0 = Number<ABlockDesc{}.GetLength(I0)>{};
-    static constexpr auto A_temp1 = Number<ABlockDesc{}.GetLength(I1)>{};
-    static constexpr auto A_temp2 = Number<ABlockDesc{}.GetLength(I2)>{};
-    static constexpr auto A_temp3 = Number<ABlockDesc{}.GetLength(I3)>{};
-    static constexpr auto A_temp4 = Number<ABlockDesc{}.GetLength(I4)>{};
-    // FIX it, workaround
-    using ABlockDesc_temp = decltype(make_naive_tensor_descriptor(
-        make_tuple(A_temp0, A_temp1, A_temp2, A_temp3, A_temp4),
-        make_tuple(A_temp1 * A_temp2 * A_temp3 * A_temp4,
-                   A_temp2 * A_temp3 * A_temp4,
-                   A_temp3 * A_temp4,
-                   A_temp4,
-                   I1)));

     static constexpr auto wmma_gemm =
         WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};

@@ -210,9 +196,6 @@ struct BlockwiseGemmWMMA
         constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
             wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

-        // constexpr auto NSubGroup =
-        // c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; constexpr auto MThreadPerSubGroup
-        // = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
         constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];

         return make_naive_tensor_descriptor_packed(

@@ -302,7 +285,7 @@ struct BlockwiseGemmWMMA
     // Describe how data allocated in thread copy src buffer
     // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
-    static constexpr ABlockDesc_temp a_block_desc_k0_m0_m1_m2_k1;
+    static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
     static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;

     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
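Note on the hunk above: the removed ABlockDesc_temp alias hand-rebuilt a packed naive descriptor from the individual lengths A_temp0..A_temp4, giving each dimension a stride equal to the product of the faster-varying lengths; the commit drops that workaround, presumably because ABlockDesc already describes the same packed layout, so a_block_desc_k0_m0_m1_m2_k1 can be declared with ABlockDesc directly. A minimal sketch of the packed-stride rule the workaround reimplemented, written in plain C++ rather than CK's Number<> machinery (function name and lengths are made up for illustration):

    // Illustrative only: packed (row-major) strides for a 5-D descriptor.
    #include <array>
    #include <cstddef>

    constexpr std::array<std::size_t, 5> packed_strides(const std::array<std::size_t, 5>& lens)
    {
        std::array<std::size_t, 5> strides{};
        std::size_t running = 1;
        for(int d = 4; d >= 0; --d) // dim 4 varies fastest, dim 0 slowest
        {
            strides[static_cast<std::size_t>(d)] = running;
            running *= lens[static_cast<std::size_t>(d)];
        }
        return strides;
    }

    static_assert(packed_strides({2, 3, 4, 5, 8})[0] == 3 * 4 * 5 * 8,
                  "dim 0 stride = product of all faster-varying lengths");
    static_assert(packed_strides({2, 3, 4, 5, 8})[4] == 1, "fastest dim has unit stride");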
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp

@@ -249,20 +249,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
         constexpr auto A_K1  = ABlockDesc_{}.GetLength(I5);

+        // Err: merge transform cause non-constexpr issue
+        // return transform_tensor_descriptor(
+        //     ABlockDesc_{},
+        //     make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
+        //                make_pass_through_transform(Number<MRepeat>{}),
+        //                make_pass_through_transform(I1),
+        //                make_pass_through_transform(I1),
+        //                make_pass_through_transform(Number<A_K1>{})),
+        //     make_tuple(Sequence<0, 3>{},
+        //                Sequence<1>{},
+        //                Sequence<2>{},
+        //                Sequence<4>{},
+        //                Sequence<5>{}),
+        //     make_tuple(
+        //         Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
+        //         Sequence<4>{}));
+        // Workaround, Freeze transform
         return transform_tensor_descriptor(
             ABlockDesc_{},
-            make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
+            make_tuple(make_freeze_transform(I0),
+                       make_pass_through_transform(Number<KWmma>{}),
                        make_pass_through_transform(Number<MRepeat>{}),
                        make_pass_through_transform(I1),
                        make_pass_through_transform(I1),
                        make_pass_through_transform(Number<A_K1>{})),
-            make_tuple(Sequence<0, 3>{},
+            make_tuple(Sequence<3>{},
+                       Sequence<0>{},
                        Sequence<1>{},
                        Sequence<2>{},
                        Sequence<4>{},
                        Sequence<5>{}),
-            make_tuple(Sequence<0>{},
-                       Sequence<1>{},
-                       Sequence<2>{},
-                       Sequence<3>{},
-                       Sequence<4>{}));
+            make_tuple(Sequence<>{},
+                       Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2>{},
+                       Sequence<3>{},
+                       Sequence<4>{}));
         }
     }();
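Note on the hunk above: per the comment left in the code, make_merge_transform over (KWmma, 1) hit a non-constexpr issue, so the workaround freezes the length-1 dimension (index 3 of ABlockDesc_) and passes KWmma through unchanged. A frozen dimension contributes no output dimension (hence the Sequence<>{} entry in the upper-dimension tuple), so the transformed descriptor keeps the same five lengths as the merge-based version. A small plain-C++ sketch of why the two formulations give the same shape (the concrete lengths below are made up for illustration, not taken from the kernel):

    // Illustrative only: merging a length-1 dim vs. freezing it yields the same shape.
    #include <cassert>
    #include <vector>

    int main()
    {
        const std::vector<int> lens{4 /*KWmma*/, 2 /*MRepeat*/, 1, 1 /*frozen*/, 1, 8 /*A_K1*/};

        // merge transform: dims {0, 3} collapse into one dim of length KWmma * 1
        const std::vector<int> merged{lens[0] * lens[3], lens[1], lens[2], lens[4], lens[5]};

        // freeze transform: dim 3 disappears, KWmma is passed through untouched
        const std::vector<int> frozen{lens[0], lens[1], lens[2], lens[4], lens[5]};

        assert(merged == frozen); // identical 5-D shape: {KWmma, MRepeat, 1, 1, A_K1}
        return 0;
    }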
@@ -455,14 +480,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
     static constexpr auto a_block_space_size_aligned =
         AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
-                                                  max_lds_align) *
-                         sizeof(FloatA)
+                                                  max_lds_align)
                    : 0;

     static constexpr auto b_block_space_size_aligned =
         BEnableLds ? math::integer_least_multiple(
                          GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(),
-                         max_lds_align) *
-                         sizeof(FloatB)
+                         max_lds_align)
                    : 0;

     static constexpr auto a_block_space_offset = 0;

@@ -471,13 +494,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
     // LDS allocation for C shuffle in LDS
     static constexpr auto c_shuffle_block_space_size =
         GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
-            .GetElementSpaceSize() *
-        sizeof(FloatCShuffle);
+            .GetElementSpaceSize();

     static constexpr auto c_shuffle_block_space_offset = 0;

-    static constexpr auto lds_size = math::max(
-        c_shuffle_block_space_size, (a_block_space_size_aligned + b_block_space_size_aligned));
+    static constexpr auto lds_size =
+        math::max(c_shuffle_block_space_size * sizeof(FloatCShuffle),
+                  a_block_space_size_aligned * sizeof(FloatA) +
+                      b_block_space_size_aligned * sizeof(FloatB));
     };

     template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
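Note on the two hunks above: a_block_space_size_aligned, b_block_space_size_aligned, and c_shuffle_block_space_size are now kept as element counts, and the conversion to bytes with sizeof(FloatA), sizeof(FloatB), and sizeof(FloatCShuffle) happens once, inside the lds_size expression, so buffers of different element types can share the single LDS pool. A minimal sketch of that byte accounting (function name, element counts, and element sizes are made up for illustration):

    // Illustrative only: mirror the new lds_size = max(C-shuffle bytes, A bytes + B bytes).
    #include <algorithm>
    #include <cstddef>

    constexpr std::size_t lds_bytes(std::size_t a_elems, std::size_t b_elems, std::size_t c_elems,
                                    std::size_t sizeof_a, std::size_t sizeof_b, std::size_t sizeof_c)
    {
        return std::max(c_elems * sizeof_c, a_elems * sizeof_a + b_elems * sizeof_b);
    }

    // e.g. fp16 A/B tiles of 1024 elements each, fp32 C-shuffle tile of 2048 elements
    static_assert(lds_bytes(1024, 1024, 2048, 2, 2, 4) == 8192,
                  "A+B tiles and the C-shuffle tile reuse the same LDS allocation");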
@@ -528,8 +552,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         }
         }();

-        // printf("---------------K = %d\n", K);
         constexpr auto a_block_desc = MakeABlockDescriptor();
         constexpr auto b_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

@@ -540,7 +562,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
             constexpr auto K0PerBlock = KPerBlock / K1;

             auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                 static_cast<FloatA*>(p_shared),
-                a_block_desc.GetElementSpaceSize());
+                SharedMemTrait::a_block_space_size_aligned);

             auto a_blockwise_copy =
                 ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,

@@ -615,8 +637,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         {
             constexpr auto K0PerBlock = KPerBlock / K1;

             auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-                static_cast<FloatB*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
-                b_block_desc.GetElementSpaceSize());
+                static_cast<FloatB*>(p_shared) + SharedMemTrait::b_block_space_offset,
+                SharedMemTrait::b_block_space_size_aligned);

             auto b_blockwise_copy =
                 ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
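Note on the two buffer hunks above: the A and B LDS tiles are carved out of the single p_shared allocation, and after this change their start offsets and sizes come from the named SharedMemTrait constants instead of being recomputed from the block descriptors at the call site. The sketch below assumes b_block_space_offset equals the aligned A-tile size, which is what the old inline expression computed; the numbers are made up:

    // Illustrative only: carving two tiles out of one shared allocation by element offset.
    #include <cassert>
    #include <cstddef>

    int main()
    {
        constexpr std::size_t a_block_space_size_aligned = 256; // elements, example value
        constexpr std::size_t a_block_space_offset       = 0;
        constexpr std::size_t b_block_space_offset = a_block_space_offset + a_block_space_size_aligned;

        float p_shared[1024] = {};                       // stand-in for the dynamic LDS pool
        float* a_buf = p_shared + a_block_space_offset;  // A tile at the base
        float* b_buf = p_shared + b_block_space_offset;  // B tile right behind it

        assert(b_buf == p_shared + a_block_space_size_aligned); // same address the old code used
        return 0;
    }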
@@ -703,7 +725,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         /*******************************************************************************/
         // Shift Per SUB_K
         constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
-        // printf("a_block_slice_copy_step FirstKdim = %d\n", a_block_slice_copy_step[I0]);
         constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();

         // gridwise GEMM pipeline

@@ -726,13 +747,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         /*******************************************************************************/
         // write out to C, implement shuffle
         {
-#if 0
-            static_for<0, c_thread_buf.Size(), 1>{}([&](auto i) {
-                printf("tid: %03d, c_thread_buf[%02d] val: %08x\n", get_thread_local_1d_id(), i.value,
-                       *(reinterpret_cast<const uint32_t*>(&(c_thread_buf[i]))));
-                // c_thread_buf(i) = 32;
-            });
-#endif
             constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                 blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

@@ -751,7 +765,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                 GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

             auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-                static_cast<FloatCShuffle*>(p_shared), SharedMemTrait::c_shuffle_block_space_size);
+                static_cast<FloatCShuffle*>(p_shared) + SharedMemTrait::c_shuffle_block_space_offset,
+                SharedMemTrait::c_shuffle_block_space_size);

             constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                 transform_tensor_descriptor(
                     c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,