Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c811a0e9
Commit
c811a0e9
authored
Feb 16, 2023
by
aska-0096
Browse files
temp save, add asm backend flag to amd_wmma
parent
c749c262
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
26 additions
and
19 deletions
+26
-19
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
...ple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
+2
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+4
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
...ation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+3
-1
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+6
-5
include/ck/utility/amd_wmma.hpp
include/ck/utility/amd_wmma.hpp
+11
-9
No files found.
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
View file @
c811a0e9
...
@@ -53,13 +53,13 @@ using DeviceConvFwdInstance =
...
@@ -53,13 +53,13 @@ using DeviceConvFwdInstance =
GemmSpec
,
// GemmSpecialization
GemmSpec
,
// GemmSpecialization
256
,
// BlockSize
256
,
// BlockSize
128
,
// MPerBlock
128
,
// MPerBlock
128
,
// NPerBlock
256
,
// NPerBlock
4
,
// K0PerBlock
4
,
// K0PerBlock
8
,
// K1
8
,
// K1
16
,
// MPerWMMA
16
,
// MPerWMMA
16
,
// NPerWMMA
16
,
// NPerWMMA
4
,
// MRepeat
4
,
// MRepeat
2
,
// NRepeat
4
,
// NRepeat
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_AK0_M_AK1
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_AK0_M_AK1
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
c811a0e9
...
@@ -375,7 +375,9 @@ template <index_t BlockSize,
...
@@ -375,7 +375,9 @@ template <index_t BlockSize,
index_t
NPerWMMA
,
index_t
NPerWMMA
,
index_t
MRepeat
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
NRepeat
,
index_t
KPack
>
index_t
KPack
,
bool
TransposeC
=
false
,
bool
AssemblyBackend
=
true
>
/* A: K0PerBlock x MPerBlock x K1
/* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
...
@@ -406,7 +408,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
...
@@ -406,7 +408,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
static
constexpr
index_t
B_K1
=
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
wmma_gemm
=
static
constexpr
auto
wmma_gemm
=
WmmaGemm
<
FloatA
,
FloatB
,
FloatAcc
,
MPerWMMA
,
NPerWMMA
,
KPack
>
{};
WmmaGemm
<
FloatA
,
FloatB
,
FloatAcc
,
MPerWMMA
,
NPerWMMA
,
KPack
,
TransposeC
,
AssemblyBackend
>
{};
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWMMA
);
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWMMA
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWMMA
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWMMA
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
View file @
c811a0e9
...
@@ -683,7 +683,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
...
@@ -683,7 +683,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
NPerWmma
,
NPerWmma
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
KPack
>
{};
KPack
,
false
,
true
>
{};
// Prepare Register for C matrix
// Prepare Register for C matrix
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
...
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @
c811a0e9
...
@@ -103,12 +103,12 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
...
@@ -103,12 +103,12 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
m_per_wmma
*
n_per_wmma
*
acc_data_size
/
wave_size
/
4
;
m_per_wmma
*
n_per_wmma
*
acc_data_size
/
wave_size
/
4
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
template
<
index_t
MPerWmma
,
index_t
NPerWmma
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
index_t
MPerWmma
,
index_t
NPerWmma
,
bool
AssemblyBackend
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
{
if
constexpr
(
wave_size
==
32
)
if
constexpr
(
wave_size
==
32
)
{
{
intrin_wmma_f32_16x16x16_f16_w32
<
MPerWmma
,
NPerWmma
>::
Run
(
a
,
b
,
reg_c
);
intrin_wmma_f32_16x16x16_f16_w32
<
MPerWmma
,
NPerWmma
,
AssemblyBackend
>::
Run
(
a
,
b
,
reg_c
);
}
}
else
if
constexpr
(
wave_size
==
64
)
else
if
constexpr
(
wave_size
==
64
)
{
{
...
@@ -358,7 +358,8 @@ template <typename src_type_a,
...
@@ -358,7 +358,8 @@ template <typename src_type_a,
index_t
MPerWmma
,
index_t
MPerWmma
,
index_t
NPerWmma
,
index_t
NPerWmma
,
index_t
KPack
,
index_t
KPack
,
bool
TransposeC
=
false
>
bool
TransposeC
=
false
,
bool
AssemblyBackend
=
false
>
struct
WmmaGemm
struct
WmmaGemm
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -491,11 +492,11 @@ struct WmmaGemm
...
@@ -491,11 +492,11 @@ struct WmmaGemm
"(int8, int32) or (int4, int32)!"
);
"(int8, int32) or (int4, int32)!"
);
if
constexpr
(
!
TransposeC
)
if
constexpr
(
!
TransposeC
)
{
{
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
>(
p_a_wave
,
p_b_wave
,
p_c_thread
);
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
,
AssemblyBackend
>(
p_a_wave
,
p_b_wave
,
p_c_thread
);
}
}
else
else
{
{
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
>(
p_b_wave
,
p_a_wave
,
p_c_thread
);
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
,
AssemblyBackend
>(
p_b_wave
,
p_a_wave
,
p_c_thread
);
}
}
}
}
...
...
include/ck/utility/amd_wmma.hpp
View file @
c811a0e9
...
@@ -12,22 +12,24 @@ namespace ck {
...
@@ -12,22 +12,24 @@ namespace ck {
/********************************WAVE32 MODE***********************************************/
/********************************WAVE32 MODE***********************************************/
// src: fp16, dst: fp32
// src: fp16, dst: fp32
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
,
bool
AssemblyBackend
>
struct
intrin_wmma_f32_16x16x16_f16_w32
;
struct
intrin_wmma_f32_16x16x16_f16_w32
;
template
<
>
template
<
bool
AssemblyBackend
>
struct
intrin_wmma_f32_16x16x16_f16_w32
<
16
,
16
>
struct
intrin_wmma_f32_16x16x16_f16_w32
<
16
,
16
,
AssemblyBackend
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half16_t
&
reg_a
,
const
half16_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
half16_t
&
reg_a
,
const
half16_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
// * Inline assembly need to elimate the duplicated data load, compiler won't help you
if
constexpr
(
AssemblyBackend
){
// delete them.
amd_assembly_wmma_f32_16x16x16_f16_w32
(
// amd_assembly_wmma_f32_16x16x16_f16_w32(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float8_t
>()(
Number
<
0
>
{}));
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
}
else
{
reg_c
.
template
AsType
<
float8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32
(
reg_c
.
template
AsType
<
float8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float8_t
>()[
Number
<
0
>
{}]);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float8_t
>()[
Number
<
0
>
{}]);
}
}
}
};
};
// src: bf16, dst: fp32
// src: bf16, dst: fp32
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment