Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
02bf2be0
Commit
02bf2be0
authored
May 18, 2021
by
Jing Zhang
Browse files
clean code
parent
dfbe7e20
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
87 additions
and
51 deletions
+87
-51
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+12
-1
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+6
-6
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
...include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
+12
-10
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+11
-20
composable_kernel/include/utility/amd_xdlops.hpp
composable_kernel/include/utility/amd_xdlops.hpp
+13
-6
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+33
-8
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
02bf2be0
...
...
@@ -13,6 +13,9 @@ namespace ck {
// GemmK = C * Y * X
template
<
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmMPerWave
,
index_t
GemmNPerWave
,
index_t
GemmKPerWave
,
typename
...
Wei
,
typename
...
In
,
typename
...
Out
,
...
...
@@ -106,9 +109,17 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
assert
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
);
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
float
,
GemmMPerWave
,
GemmNPerWave
,
GemmKPerWave
>
{};
constexpr
auto
CLayout
=
xdlops_gemm
.
GetOutputLayout
();
constexpr
index_t
M0
=
CLayout
.
M1
();
constexpr
index_t
M1
=
CLayout
.
N1
();
constexpr
index_t
M2
=
CLayout
.
M0
();
const
auto
out_m0_m1_m2_n_global_desc
=
transform_dynamic_tensor_descriptor
(
out_gemmm_gemmn_global_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM
/
8
,
2
,
4
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM
/
(
M1
*
M2
),
M1
,
M2
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}));
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
02bf2be0
...
...
@@ -26,7 +26,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
X
dlops
G
emm
=
XdlopsGemm
_t
<
float
,
MPerWave
,
NPerWave
,
KPerWave
>
{};
static
constexpr
auto
x
dlops
_g
emm
=
XdlopsGemm
<
float
,
MPerWave
,
NPerWave
,
KPerWave
>
{};
static
constexpr
index_t
WaveSize
=
64
;
...
...
@@ -35,16 +35,16 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static
constexpr
index_t
MWaves
=
MPerBlock
/
MPerWave
;
static
constexpr
index_t
NWaves
=
NPerBlock
/
NPerWave
;
__device__
constexpr
auto
GetOutputLayout
()
const
{
return
X
dlops
G
emm
.
GetOutputLayout
();
}
__device__
constexpr
auto
GetOutputLayout
()
const
{
return
x
dlops
_g
emm
.
GetOutputLayout
();
}
__device__
constexpr
auto
GetNumBlks
()
const
{
return
X
dlops
G
emm
.
GetOutputLayout
().
GetNumBlks
();
return
x
dlops
_g
emm
.
GetOutputLayout
().
GetNumBlks
();
}
__device__
constexpr
auto
GetBlkSize
()
const
{
return
X
dlops
G
emm
.
GetOutputLayout
().
GetBlkSize
();
return
x
dlops
_g
emm
.
GetOutputLayout
().
GetBlkSize
();
}
__device__
static
auto
CalculateAThreadOriginDataIndex
()
...
...
@@ -75,7 +75,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const
index_t
waveId
=
get_thread_local_1d_id
()
/
WaveSize
;
const
auto
thread_mtx_on_blk
=
X
dlops
G
emm
.
GetBeginOfThreadBlk
(
blk_i
);
const
auto
thread_mtx_on_blk
=
x
dlops
_g
emm
.
GetBeginOfThreadBlk
(
blk_i
);
const
index_t
row
=
(
waveId
/
NWaves
)
*
AStride
+
thread_mtx_on_blk
.
row
;
const
index_t
col
=
(
waveId
%
NWaves
)
*
BStride
+
thread_mtx_on_blk
.
col
;
...
...
@@ -127,7 +127,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
make_tuple
(
I0
,
I0
),
b_thread_buf
);
X
dlops
G
emm
.
template
Run
(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
x
dlops
_g
emm
.
template
Run
(
a_thread_buf
,
b_thread_buf
,
c_thread_buf
);
});
}
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
View file @
02bf2be0
...
...
@@ -333,7 +333,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
//.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
vector_type
<
float
,
64
>
c_thread_buf
;
constexpr
auto
c_vec_size
=
MPerBlock
*
NPerBlock
/
BlockSize
;
vector_type
<
float
,
c_vec_size
>
c_thread_buf
;
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
,
0
);
...
...
@@ -466,15 +468,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
{
constexpr
auto
OutputLayout
=
blockwise_gemm
.
GetOutputLayout
();
constexpr
index_t
K
0
=
OutputLayout
.
M1
();
constexpr
index_t
K
1
=
OutputLayout
.
N1
();
constexpr
index_t
K
2
=
OutputLayout
.
M0
();
constexpr
index_t
M
0
=
OutputLayout
.
M1
();
constexpr
index_t
M
1
=
OutputLayout
.
N1
();
constexpr
index_t
M
2
=
OutputLayout
.
M0
();
// static_assert(
K
0 == 4 &&
K
1 == 2 &&
K
2 == 4, "");
// static_assert(
M
0 == 4 &&
M
1 == 2 &&
M
2 == 4, "");
constexpr
auto
c_m0_m1_m2_n_thread_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
K
0
>
{},
Number
<
1
>
{},
Number
<
K
2
>
{},
Number
<
1
>
{}));
make_tuple
(
Number
<
M
0
>
{},
Number
<
1
>
{},
Number
<
M
2
>
{},
Number
<
1
>
{}));
constexpr
index_t
BlkSize
=
OutputLayout
.
GetBlkSize
();
constexpr
index_t
NumBlks
=
OutputLayout
.
GetNumBlks
();
...
...
@@ -508,16 +510,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
FloatC
,
decltype
(
c_m0_m1_m2_n_thread_desc
),
decltype
(
c_m0_m1_m2_n_global_desc
),
Sequence
<
K
0
,
1
,
K
2
,
1
>
,
Sequence
<
M
0
,
1
,
M
2
,
1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
// CThreadTransferSrcDstAccessOrder,
3
,
// CThreadTransferSrcDstVectorDim,
1
,
// CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation
,
1
,
true
>
{
c_m0_m1_m2_n_global_desc
,
make_multi_index
(
k_thread_data_on_global
/
(
K
2
*
K
1
),
k_thread_data_on_global
%
(
K
2
*
K
1
)
/
K
2
,
k_thread_data_on_global
%
K
2
,
make_multi_index
(
k_thread_data_on_global
/
(
M
2
*
M
1
),
k_thread_data_on_global
%
(
M
2
*
M
1
)
/
M
2
,
k_thread_data_on_global
%
M
2
,
b_thread_data_on_global
)}
.
Run
(
c_m0_m1_m2_n_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
02bf2be0
...
...
@@ -53,7 +53,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
return
intrin_mfma_f32_32x32x1f32
<
MPerXdlops
,
NPerXdlops
>::
r
un
(
a
,
b
,
reg_c
);
return
intrin_mfma_f32_32x32x1f32
<
MPerXdlops
,
NPerXdlops
>::
R
un
(
a
,
b
,
reg_c
);
}
};
...
...
@@ -74,19 +74,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
static
constexpr
index_t
cycles
=
64
;
static
constexpr
index_t
k_base
=
1
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
AStride
,
index_t
BStride
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
FloatC
run
(
const
FloatA
*
a
,
const
FloatB
*
b
,
FloatC
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
const
auto
p_a
=
reinterpret_cast
<
const
float
*>
(
a
);
const
auto
p_b
=
reinterpret_cast
<
const
float
*>
(
b
);
return
intrin_mfma_f32_32x32x2f32
(
p_a
,
p_b
,
reg_c
);
return
intrin_mfma_f32_32x32x2f32
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
};
...
...
@@ -548,7 +539,7 @@ struct xdlops_info
};
template
<
class
data_type
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KPerWave
>
struct
XdlopsGemm
_t
struct
XdlopsGemm
{
struct
MatrixIndex
{
...
...
@@ -561,7 +552,7 @@ struct XdlopsGemm_t
return
(
MPerXdlops
*
NPerXdlops
)
/
(
mfma_type
.
m
*
mfma_type
.
n
);
}
__device__
constexpr
XdlopsGemm
_t
()
__host__
__device__
constexpr
XdlopsGemm
()
{
static_assert
(
NPerXdlops
==
4
||
NPerXdlops
==
8
||
NPerXdlops
==
16
||
NPerXdlops
==
32
||
NPerXdlops
==
64
,
...
...
@@ -849,10 +840,10 @@ struct XdlopsGemm_t
struct
OutputLayout
{
__device__
static
constexpr
index_t
M1
()
{
return
mfma_type
.
num_groups_blk
;
}
__device__
static
constexpr
index_t
M0
()
{
return
mfma_type
.
group_size
;
}
__device__
static
constexpr
index_t
N1
()
{
return
mfma_type
.
num_input_blks
;
}
__device__
static
constexpr
index_t
N0
()
{
return
mfma_type
.
num_threads_blk
;
}
__host__
__device__
static
constexpr
index_t
M1
()
{
return
mfma_type
.
num_groups_blk
;
}
__host__
__device__
static
constexpr
index_t
M0
()
{
return
mfma_type
.
group_size
;
}
__host__
__device__
static
constexpr
index_t
N1
()
{
return
mfma_type
.
num_input_blks
;
}
__host__
__device__
static
constexpr
index_t
N0
()
{
return
mfma_type
.
num_threads_blk
;
}
__device__
static
constexpr
index_t
GetBlkSize
()
{
return
mfma_type
.
num_regs_blk
;
}
...
...
@@ -867,7 +858,7 @@ struct XdlopsGemm_t
}
};
__device__
static
constexpr
auto
GetOutputLayout
()
{
return
OutputLayout
{};
}
__host__
__device__
static
constexpr
auto
GetOutputLayout
()
{
return
OutputLayout
{};
}
};
}
// namespace ck
...
...
composable_kernel/include/utility/amd_xdlops.hpp
View file @
02bf2be0
...
...
@@ -241,7 +241,7 @@ template <>
struct
intrin_mfma_f32_32x32x1f32
<
64
,
64
>
{
__device__
static
void
r
un
(
const
float
&
reg_a
,
const
float
&
reg_b
,
vector_type
<
float
,
64
>&
reg_c
)
R
un
(
const
float
&
reg_a
,
const
float
&
reg_b
,
vector_type
<
float
,
64
>&
reg_c
)
{
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
0
,
0
);
...
...
@@ -272,12 +272,19 @@ struct intrin_mfma_f32_32x32x1f32<64, 64>
//}
//};
__device__
c_vec16_1_t
::
VecType
intrin_mfma_f32_32x32x2f32
(
const
float
*
reg_a
,
const
float
*
reg_b
,
c_vec16_1_t
::
VecType
reg_c
)
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x2f32
;
template
<
>
struct
intrin_mfma_f32_32x32x2f32
<
32
,
32
>
{
reg_c
.
s
.
x
=
llvm_intrin_amdgcn_mfma_f32_32x32x2f32
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
.
s
.
x
,
0
,
0
,
0
);
return
reg_c
;
}
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
vector_type
<
float
,
16
>&
reg_c
)
{
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x2f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
};
__device__
c_vec4_1_t
::
VecType
intrin_mfma_f32_16x16x4f32
(
const
float
*
reg_a
,
const
float
*
reg_b
,
c_vec4_1_t
::
VecType
reg_c
)
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
02bf2be0
...
...
@@ -71,13 +71,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
const
auto
out_n_k_ho_wo_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
sequence_to_tuple_of_number
(
OutDesc
::
GetLengths
()));
const
auto
conv_strides
=
sequence_to_tuple_of_number
(
ConvStrides
{});
const
auto
conv_dilations
=
sequence_to_tuple_of_number
(
ConvDilations
{});
const
auto
in_left_pads
=
sequence_to_tuple_of_number
(
InLeftPads
{});
const
auto
in_right_pads
=
sequence_to_tuple_of_number
(
InRightPads
{});
const
auto
conv_strides
=
sequence_to_tuple_of_number
(
ConvStrides
{});
const
auto
conv_dilations
=
sequence_to_tuple_of_number
(
ConvDilations
{});
const
auto
in_left_pads
=
sequence_to_tuple_of_number
(
InLeftPads
{});
const
auto
in_right_pads
=
sequence_to_tuple_of_number
(
InRightPads
{});
#endif
// b thread copy 4x1
#if 0
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 64;
...
...
@@ -101,13 +101,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#else
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
GemmMPerBlock
=
32
;
constexpr
index_t
GemmNPerBlock
=
32
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerWave
=
32
;
constexpr
index_t
GemmNPerWave
=
32
;
constexpr
index_t
GemmKPerWave
=
2
;
constexpr
index_t
GemmM1
=
GemmMPerWave
;
constexpr
index_t
GemmN1
=
GemmNPerWave
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
32
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
1
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
2
,
32
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
1
;
#endif
const
auto
descs
=
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
<
GemmMPerBlock
,
GemmNPerBlock
>
(
GemmNPerBlock
,
GemmMPerWave
,
GemmNPerWave
,
GemmKPerWave
>
(
wei_k_c_y_x_desc
,
in_n_c_hi_wi_desc
,
out_n_k_ho_wo_desc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment