gaoqiong / composable_kernel / Commits / 5bf77d8b
"git@developer.sourcefind.cn:OpenDAS/torch-spline-conv.git" did not exist on "dff93289bcc7673ed46d5192eb73196e3fe31efc"
Commit 5bf77d8b, authored May 10, 2023 by aska-0096
clang-format
Parent: 2f88070a
Showing 5 changed files with 27 additions and 38 deletions.
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp (+6, -13)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp (+3, -3)
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp (+4, -3)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp (+4, -2)
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp (+10, -17)
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

@@ -211,27 +211,20 @@ struct BlockwiseGemmWMMA
```cpp
        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

        constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
        constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
        return make_naive_tensor_descriptor(
            // |MRepeat |MWave |MSubGroup |NRepeat |NWave
            // |NThreadPerSubGroup |MAccVgprs
            make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
            make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
                       Number<NRepeat>{} * MAccVgprs * AccStride,
                       Number<NRepeat>{} * MAccVgprs * AccStride,
                       MAccVgprs * AccStride,
                       MAccVgprs * AccStride,
                       MAccVgprs * AccStride,
                       AccStride));
#if 0
        return make_naive_tensor_descriptor_packed(
            // |MRepeat |MWave |MSubGroup |NRepeat |NWave
            // |NThreadPerSubGroup |MAccVgprs
            ...
```

@@ -242,7 +235,7 @@ struct BlockwiseGemmWMMA
```cpp
                I1,
                NThreadPerSubGroup,
                MAccVgprs));
#endif
    }

    template <typename CGridDesc_M_N>
    ...
```
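For context on the hunk above: `make_naive_tensor_descriptor` builds a descriptor from per-dimension lengths and explicit strides, so a coordinate's linear offset is the dot product of indices and strides. A minimal host-side sketch of that addressing rule (plain C++ with illustrative names, not CK's actual descriptor machinery):

```cpp
#include <array>
#include <cstddef>
#include <iostream>

// Illustrative stand-in for a "naive" descriptor: lengths plus strides,
// with offset(idx) = sum_i idx[i] * stride[i].
template <std::size_t N>
struct NaiveDescriptor
{
    std::array<std::size_t, N> lengths;
    std::array<std::size_t, N> strides;

    constexpr std::size_t offset(const std::array<std::size_t, N>& idx) const
    {
        std::size_t off = 0;
        for(std::size_t i = 0; i < N; ++i)
            off += idx[i] * strides[i];
        return off;
    }
};

int main()
{
    // 2D example: a 4 x 8 row-major view has strides {8, 1}.
    constexpr NaiveDescriptor<2> desc{{4, 8}, {8, 1}};
    static_assert(desc.offset({2, 3}) == 19, "2*8 + 3*1");
    std::cout << desc.offset({2, 3}) << '\n'; // prints 19
}
```

The change in this hunk swaps a packed descriptor (strides derived from lengths) for one with explicit strides built from `MAccVgprs` and `AccStride`; the old packed call is kept under `#if 0`.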
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp

@@ -151,9 +151,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
```cpp
    static constexpr auto B0EnableLds_manu = true;
    static constexpr auto B1EnableLds_manu = true;

    static constexpr auto AEnableLds  = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
    static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1);
    static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1);

    using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
        Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,
        ...
```
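The `*EnableLds` flags above combine three conditions: an automatic heuristic, a manual override, and a hard requirement that multi-stage prefetch be staged through LDS. A minimal sketch of that predicate in isolation (illustrative names, not the device op itself):

```cpp
#include <cstdio>

// Illustrative: LDS staging is enabled if the automatic heuristic asks for
// it, the user forces it, or multi-stage prefetch requires it.
template <bool EnableLdsAuto, bool EnableLdsManu, int NumPrefetch>
struct LdsPolicy
{
    static constexpr bool enable_lds =
        EnableLdsAuto || EnableLdsManu || (NumPrefetch > 1);
};

int main()
{
    // Single-stage prefetch with no overrides: direct-to-register path allowed.
    static_assert(!LdsPolicy<false, false, 1>::enable_lds, "no LDS needed");
    // Two prefetch stages force LDS regardless of the other knobs.
    static_assert(LdsPolicy<false, false, 2>::enable_lds, "prefetch forces LDS");
    std::puts("ok");
}
```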
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp

@@ -94,8 +94,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
```cpp
    static constexpr auto AEnableLds_manu = false;
    static constexpr auto BEnableLds_manu = false;

    static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
    static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
    ...
```

@@ -467,7 +467,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
```cpp
        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
           ck::get_device_name() == "gfx1102")
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
                           is_same_v<AccDataType, int32_t>))
            {
                printf("DeviceOp err: AccDataType");
                return false;
                ...
```
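The second hunk mixes a runtime device-name gate (gfx1100/gfx1101/gfx1102, the RDNA3 targets for these WMMA instructions) with a compile-time accumulator-type check via `if constexpr`. A stripped-down, compilable sketch of that pattern (the free functions here are hypothetical stand-ins, not the actual device op or `ck::get_device_name`):

```cpp
#include <cstdint>
#include <cstdio>
#include <string>
#include <type_traits>

// Hypothetical stand-in for ck::get_device_name().
std::string get_device_name() { return "gfx1100"; }

template <typename AccDataType>
bool is_supported()
{
    const std::string dev = get_device_name();
    if(dev == "gfx1100" || dev == "gfx1101" || dev == "gfx1102")
    {
        // The discarded if-constexpr branch is never instantiated, so an
        // unsupported accumulator type reports failure instead of breaking
        // the build.
        if constexpr(!(std::is_same_v<AccDataType, float> ||
                       std::is_same_v<AccDataType, std::int32_t>))
        {
            std::printf("DeviceOp err: AccDataType\n");
            return false;
        }
        return true;
    }
    return false; // unsupported device
}

int main()
{
    std::printf("%d\n", is_supported<float>());  // 1
    std::printf("%d\n", is_supported<double>()); // 0, with diagnostic
}
```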
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp

@@ -177,8 +177,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
```cpp
    static constexpr auto AEnableLds_manu = false;
    static constexpr auto BEnableLds_manu = false;

    static constexpr auto AEnableLds =
        AEnableLds_auto || AEnableLds_manu || (NumGemmKPrefetchStage > 1);
    static constexpr auto BEnableLds =
        BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);

    static constexpr auto conv_to_gemm_transformer =
        TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
    ...
```
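`TransformConvFwdToGemm` realizes forward convolution as an implicit GEMM. As a hedged aside (this is the standard implicit-GEMM bookkeeping, not code from this file): for an NHWC forward conv, GemmM = N·Ho·Wo, GemmN = K, and GemmK = C·Y·X. A tiny sanity check of those sizes:

```cpp
#include <iostream>

int main()
{
    // Illustrative NHWC forward-conv shape (not taken from the diff):
    const long N = 2, Hi = 8, Wi = 8, C = 16;    // input
    const long K = 32, Y = 3, X = 3;             // weights
    const long Ho = Hi - Y + 1, Wo = Wi - X + 1; // stride 1, no padding

    // Implicit-GEMM view: out[N*Ho*Wo, K] = in[N*Ho*Wo, C*Y*X] * w[C*Y*X, K]
    const long GemmM = N * Ho * Wo, GemmN = K, GemmK = C * Y * X;
    std::cout << GemmM << " x " << GemmN << " x " << GemmK << '\n'; // 72 x 32 x 144
}
```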
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp

@@ -104,11 +104,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
```cpp
        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        ...
```

@@ -142,7 +138,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
```cpp
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
    ...
```

@@ -182,11 +178,7 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
```cpp
        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        ...
```

@@ -261,7 +253,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
```cpp
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma,
              ...
```
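The `num_acc_vgprs_per_wave` expression in these hunks counts 4-byte VGPRs per lane: the accumulator tile holds m_per_wmma x n_per_wmma values spread over wave_size lanes, each occupying acc_data_size * acc_pack_number bytes. Worked through for an f32 16x16x16 tile on a wave32 target (the concrete parameter values are an assumption based on the instruction names, not read from the diff):

```cpp
#include <iostream>

int main()
{
    // Assumed f32 WMMA tile parameters on wave32:
    const int m_per_wmma = 16, n_per_wmma = 16;
    const int acc_data_size = 4, acc_pack_number = 1; // fp32 accumulator
    const int wave_size = 32;

    // Accumulator bytes per lane, divided by 4 bytes per VGPR.
    const int num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
    std::cout << num_acc_vgprs_per_wave << '\n'; // 16*16*4*1 / 32 / 4 = 8
}
```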
@@ -496,13 +488,11 @@ struct WmmaGemm
```cpp
                      "(int8, int32) or (int4, int32)!");

        if constexpr(!TransposeC)
        {
            wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread);
        }
        else
        {
            wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread);
        }
    }
```
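The operand swap in the `TransposeC` branch rests on the identity (A·B)ᵀ = Bᵀ·Aᵀ: feeding the instruction its operands in reverse order leaves the transposed product in the accumulator with no separate transpose pass. A scalar-loop illustration of the identity itself (plain C++, nothing WMMA-specific):

```cpp
#include <array>
#include <cstdio>

using Mat2 = std::array<std::array<int, 2>, 2>;

Mat2 mul(const Mat2& a, const Mat2& b)
{
    Mat2 c{};
    for(int i = 0; i < 2; ++i)
        for(int j = 0; j < 2; ++j)
            for(int k = 0; k < 2; ++k)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}

Mat2 transpose(const Mat2& m)
{
    return {{{m[0][0], m[1][0]}, {m[0][1], m[1][1]}}};
}

int main()
{
    const Mat2 A{{{1, 2}, {3, 4}}}, B{{{5, 6}, {7, 8}}};
    // (A*B)^T == B^T * A^T, so swapping operands yields C^T.
    const Mat2 lhs = transpose(mul(A, B));
    const Mat2 rhs = mul(transpose(B), transpose(A));
    std::printf("%s\n", lhs == rhs ? "identity holds" : "mismatch");
}
```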
@@ -555,7 +545,10 @@ struct WmmaGemm
```cpp
    __host__ __device__ static constexpr auto
    GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
    {
        return make_tuple(I1,
                          I1,
                          Number<wmma_instr.num_acc_vgprs_per_wave>{},
                          Number<wmma_instr.acc_pack_number>{});
    }
};
```
...
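`GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths` returns the block lengths as a tuple of compile-time integers, which is what lets the caller in blockwise_gemm_wmma.hpp read `MAccVgprs` and `AccStride` out of it in a `constexpr` context. A minimal sketch of that tuple-of-integral-constants idiom, with `std::integral_constant` standing in for CK's `Number<>` and an assumed length of 8:

```cpp
#include <tuple>
#include <type_traits>

// std::integral_constant plays the role of ck::Number<> here.
template <int V>
using Num = std::integral_constant<int, V>;

// Illustrative analogue of the lengths tuple:
// (I1, I1, num_acc_vgprs_per_wave, acc_pack_number).
constexpr auto get_thread_blk_lengths()
{
    return std::make_tuple(Num<1>{}, Num<1>{}, Num<8>{}, Num<1>{});
}

int main()
{
    constexpr auto lens = get_thread_blk_lengths();
    // The values are carried in the types, so they remain usable as
    // constant expressions after the tuple is unpacked.
    constexpr int MAccVgprs = std::get<2>(lens)();
    constexpr int AccStride = std::get<3>(lens)();
    static_assert(MAccVgprs == 8 && AccStride == 1, "compile-time lengths");
}
```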