Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c982e753
Commit
c982e753
authored
Aug 18, 2021
by
Jing Zhang
Browse files
add make c into xdlops-gemm
parent
62ebdfde
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
54 additions
and
69 deletions
+54
-69
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+23
-31
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
+1
-1
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+10
-17
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
...on_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
+20
-20
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
c982e753
...
@@ -40,14 +40,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -40,14 +40,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
__device__
static
constexpr
auto
GetCM0N0M1N1M2M3M4N2ThreadDesc
()
{
constexpr
auto
M0
=
Number
<
CXdlopsLayout
.
M1
()
>
{};
constexpr
auto
M2
=
Number
<
CXdlopsLayout
.
M0
()
>
{};
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
I1
,
M0
,
I1
,
M2
,
I1
));
}
__device__
static
auto
GetWaveIdx
()
__device__
static
auto
GetWaveIdx
()
{
{
const
index_t
thread_id
=
get_thread_local_1d_id
();
const
index_t
thread_id
=
get_thread_local_1d_id
();
...
@@ -131,39 +123,39 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -131,39 +123,39 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_assert
(
NumBlks
==
1
&&
NumXdlops
==
1
,
"K Reduction Mfma only"
);
static_assert
(
NumBlks
==
1
&&
NumXdlops
==
1
,
"K Reduction Mfma only"
);
}
}
__host__
__device__
static
constexpr
auto
GetCM0N0M1N1M2M3M4N2
Block
Descriptor
()
__host__
__device__
static
constexpr
auto
GetCM0N0M1N1M2M3M4N2
Thread
Descriptor
()
{
{
constexpr
auto
M2
=
Number
<
CXdlopsLayout
.
M1
()
>
{};
constexpr
auto
M0
=
Number
<
CXdlopsLayout
.
M1
()
>
{};
constexpr
auto
M3
=
Number
<
CXdlopsLayout
.
N1
()
>
{};
constexpr
auto
M2
=
Number
<
CXdlopsLayout
.
M0
()
>
{};
constexpr
auto
M4
=
Number
<
CXdlopsLayout
.
M0
()
>
{};
constexpr
auto
N2
=
Number
<
CXdlopsLayout
.
N0
()
>
{};
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
I1
,
M0
,
I1
,
M2
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetCM0N0M1N1M2M3M4N2BlockDescriptor
()
{
constexpr
auto
c_m0_n0_m1_n1_m2_n2_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
NWaves
>
{},
Number
<
M
2
>
{},
Number
<
M
PerXDL
>
{},
Number
<
M3
>
{},
Number
<
NPerXDL
>
{}));
Number
<
M4
>
{},
Number
<
N2
>
{})
);
return
xdlops_gemm
.
MakeCM0N0M1N1M2M3M4N2Descriptor
(
c_m0_n0_m1_n1_m2_n2_block_desc
);
}
}
template
<
typename
CMNGridDesc
>
template
<
typename
CMNGridDesc
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
{
{
///\To-do: pass CGrid desc transform deep inside xdlops gemm
const
auto
c_m0_n0_m1_n1_m2_n2_grid_desc
=
transform_tensor_descriptor
(
constexpr
auto
M2
=
Number
<
CXdlopsLayout
.
M1
()
>
{};
constexpr
auto
M3
=
Number
<
CXdlopsLayout
.
N1
()
>
{};
constexpr
auto
M4
=
Number
<
CXdlopsLayout
.
M0
()
>
{};
constexpr
auto
N2
=
Number
<
CXdlopsLayout
.
N0
()
>
{};
return
transform_tensor_descriptor
(
c_m_n_grid_desc
,
c_m_n_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
M
2
,
M3
,
M4
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
M
PerXDL
)),
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
N
2
))),
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
N
PerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<
1
,
3
,
7
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
return
xdlops_gemm
.
MakeCM0N0M1N1M2M3M4N2Descriptor
(
c_m0_n0_m1_n1_m2_n2_grid_desc
);
}
}
__host__
__device__
static
constexpr
auto
MakeAK0M0M1M2K1BlockDescriptor
()
__host__
__device__
static
constexpr
auto
MakeAK0M0M1M2K1BlockDescriptor
()
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
View file @
c982e753
...
@@ -376,7 +376,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -376,7 +376,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{}));
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{}));
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
=
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
=
blockwise_gemm
.
GetCM0N0M1N1M2M3M4N2ThreadDesc
();
blockwise_gemm
.
GetCM0N0M1N1M2M3M4N2ThreadDesc
riptor
();
constexpr
auto
CBlkSize
=
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
.
GetElementSpaceSize
();
constexpr
auto
CBlkSize
=
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
.
GetElementSpaceSize
();
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
c982e753
...
@@ -690,24 +690,17 @@ struct XdlopsGemm
...
@@ -690,24 +690,17 @@ struct XdlopsGemm
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"
);
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"
);
}
}
template
<
typename
CM0N0M1N1M2N2
Grid
Desc
>
template
<
typename
CM0N0M1N1M2N2Desc
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
const
CM0N0M1N1M2N2GridDesc
&
c_m0_n0_m1_n1_m2_n2_grid_desc
)
MakeCM0N0M1N1M2M3M4N2Descriptor
(
const
CM0N0M1N1M2N2Desc
&
c_m0_n0_m1_n1_m2_n2_desc
)
{
{
constexpr
auto
M0
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I0
);
const
auto
M0
=
c_m0_n0_m1_n1_m2_n2_desc
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I1
);
const
auto
N0
=
c_m0_n0_m1_n1_m2_n2_desc
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I2
);
const
auto
M1
=
c_m0_n0_m1_n1_m2_n2_desc
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I3
);
const
auto
N1
=
c_m0_n0_m1_n1_m2_n2_desc
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I4
);
constexpr
auto
N2
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I5
);
return
transform_tensor_descriptor
(
c_m0_n0_m1_n1_m2_n2_desc
,
static_assert
(
N2
==
mfma_type
.
num_threads_per_blk
,
""
);
static_assert
(
M2
==
(
mfma_type
.
num_groups_per_blk
*
mfma_type
.
num_output_blks
*
mfma_type
.
group_size
),
""
);
return
transform_dynamic_tensor_descriptor
(
c_m0_n0_m1_n1_m2_n2_grid_desc
,
make_tuple
(
make_pass_through_transform
(
M0
),
make_tuple
(
make_pass_through_transform
(
M0
),
make_pass_through_transform
(
N0
),
make_pass_through_transform
(
N0
),
make_pass_through_transform
(
M1
),
make_pass_through_transform
(
M1
),
...
...
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
View file @
c982e753
...
@@ -48,10 +48,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
...
@@ -48,10 +48,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const
auto
out_n_k_ho_wo_desc
=
make_naive_tensor_descriptor_packed
(
out_n_k_ho_wo_lengths
);
const
auto
out_n_k_ho_wo_desc
=
make_naive_tensor_descriptor_packed
(
out_n_k_ho_wo_lengths
);
#if 1
#if 1
// [M, N, K0, K1] = [
256
, 128, 4, 8] for fp16
// [M, N, K0, K1] = [
128
, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmKPerBlock
=
4
;
...
@@ -59,10 +59,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
...
@@ -59,10 +59,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
constexpr
index_t
GemmNPerWave
=
32
;
constexpr
index_t
GemmNPerWave
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
4
,
8
>
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
2
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
8
;
...
@@ -106,22 +106,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
...
@@ -106,22 +106,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{}));
constexpr
auto
out_m0_m1_m2_n_grid_step_hacks
=
constexpr
auto
out_m0_m1_m2_n_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
=
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment