Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a52e5a92
Commit
a52e5a92
authored
Sep 16, 2021
by
ltqin
Browse files
finish driver_gemm_xdlops file
parent
a3b31a92
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
107 additions
and
107 deletions
+107
-107
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+1
-1
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp
+11
-10
host/driver_offline/include/driver_gemm_xdlops_v3r1.hpp
host/driver_offline/include/driver_gemm_xdlops_v3r1.hpp
+95
-96
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
a52e5a92
...
@@ -158,7 +158,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -158,7 +158,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return
xdlops_gemm
.
MakeCM0N0M1N1M2M3M4N2Descriptor
(
c_m0_n0_m1_n1_m2_n2_grid_desc
);
return
xdlops_gemm
.
MakeCM0N0M1N1M2M3M4N2Descriptor
(
c_m0_n0_m1_n1_m2_n2_grid_desc
);
}
}
template
<
typename
CGMNGridDesc
>
template
<
typename
CGMNGridDesc
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeCGM0N0M1N1M2M3M4N2GridDescriptor
(
const
CGMNGridDesc
&
c_g_m_n_grid_desc
)
MakeCGM0N0M1N1M2M3M4N2GridDescriptor
(
const
CGMNGridDesc
&
c_g_m_n_grid_desc
)
{
{
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp
View file @
a52e5a92
...
@@ -246,22 +246,23 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
...
@@ -246,22 +246,23 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
const
auto
N0
=
N
/
N1
;
const
auto
N0
=
N
/
N1
;
#if 1
#if 1
const
auto
c_blockid_to_m0_n0_block_cluster_adaptor
=
const
auto
c_blockid_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
G
,
M0
,
N0
))),
make_tuple
(
make_merge_transform
(
make_tuple
(
G
,
M0
,
N0
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
make_tuple
(
Sequence
<
0
>
{}));
#elif 1
#elif 1
const
auto
c_blockid_to_m0_n0_block_cluster_adaptor
=
const
auto
c_blockid_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
G
,
N0
,
M0
))),
make_tuple
(
make_merge_transform
(
make_tuple
(
G
,
N0
,
M0
))),
make_tuple
(
Sequence
<
0
,
2
,
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
1
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
make_tuple
(
Sequence
<
0
>
{}));
#endif
#endif
return
c_blockid_to_m0_n0_block_cluster_adaptor
;
return
c_blockid_to_m0_n0_block_cluster_adaptor
;
}
}
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
MakeCGM0N0M1N1M2M3M4N2GridDescriptor
(
CGMNGridDesc
{}));
using
CM0N0M1N1M2M3M4N2GridDesc
=
using
CBlockClusterAdaptor
=
decltype
(
MakeCBlockClusterAdaptor
(
CGMNGridDesc
{}));
decltype
(
MakeCGM0N0M1N1M2M3M4N2GridDescriptor
(
CGMNGridDesc
{}));
using
CBlockClusterAdaptor
=
decltype
(
MakeCBlockClusterAdaptor
(
CGMNGridDesc
{}));
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
...
...
host/driver_offline/include/driver_gemm_xdlops_v3r1.hpp
View file @
a52e5a92
...
@@ -70,46 +70,46 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
...
@@ -70,46 +70,46 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
using
GridwiseGemm
=
using
GridwiseGemm
=
GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
<
BlockSize
,
GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
FloatAcc
,
FloatC
,
FloatC
,
CGlobalMemoryDataOperation
,
CGlobalMemoryDataOperation
,
AGK0MK1GridDesc
,
AGK0MK1GridDesc
,
BGK0NK1GridDesc
,
BGK0NK1GridDesc
,
CGMNGridDesc
,
CGMNGridDesc
,
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
KPerBlock
,
KPerBlock
,
MPerXDL
,
MPerXDL
,
NPerXDL
,
NPerXDL
,
K1
,
K1
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_G_K0_M_K1
,
ABlockTransferThreadSliceLengths_G_K0_M_K1
,
ABlockTransferThreadClusterLengths_G_K0_M_K1
,
ABlockTransferThreadClusterLengths_G_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
ABlockTransferDstScalarPerVector_K1
,
AThreadTransferSrcResetCoordinateAfterRun
,
AThreadTransferSrcResetCoordinateAfterRun
,
BBlockTransferThreadSliceLengths_G_K0_N_K1
,
BBlockTransferThreadSliceLengths_G_K0_N_K1
,
BBlockTransferThreadClusterLengths_G_K0_N_K1
,
BBlockTransferThreadClusterLengths_G_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
BBlockTransferDstScalarPerVector_K1
,
BThreadTransferSrcResetCoordinateAfterRun
,
BThreadTransferSrcResetCoordinateAfterRun
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AGridStepHacks
,
AGridStepHacks
,
BGridStepHacks
,
BGridStepHacks
,
CGridStepHacks
,
CGridStepHacks
,
AGridMoveSliceWindowStepHacks
,
AGridMoveSliceWindowStepHacks
,
BGridMoveSliceWindowStepHacks
,
BGridMoveSliceWindowStepHacks
,
CAccessOrderMRepeatNRepeat
>
;
CAccessOrderMRepeatNRepeat
>
;
{
{
std
::
cout
<<
"a_g_k0_m_k1_grid_desc{"
<<
a_g_k0_m_k1_grid_desc
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"a_g_k0_m_k1_grid_desc{"
<<
a_g_k0_m_k1_grid_desc
.
GetLength
(
I0
)
<<
", "
...
@@ -134,66 +134,65 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
...
@@ -134,66 +134,65 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"
);
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"
);
}
}
const
auto
c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
=
const
auto
c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
=
GridwiseGemm
::
MakeCGM0N0M1N1M2M3M4N2GridDescriptor
(
c_g_m_n_grid_desc
);
GridwiseGemm
::
MakeCGM0N0M1N1M2M3M4N2GridDescriptor
(
c_g_m_n_grid_desc
);
using
CGM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
using
CGM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
const
auto
c_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_g_m_n_grid_desc
);
const
auto
c_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_g_m_n_grid_desc
);
using
CBlockClusterAdaptor
=
decltype
(
c_block_cluster_adaptor
);
using
CBlockClusterAdaptor
=
decltype
(
c_block_cluster_adaptor
);
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
c_g_m_n_grid_desc
);
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
c_g_m_n_grid_desc
);
const
auto
kernel
=
kernel_gemm_xdlops_v3r1
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdlops_v3r1
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AGK0MK1GridDesc
>
,
remove_reference_t
<
AGK0MK1GridDesc
>
,
remove_reference_t
<
BGK0NK1GridDesc
>
,
remove_reference_t
<
BGK0NK1GridDesc
>
,
remove_reference_t
<
CGM0N0M1N1M2M3M4N2GridDesc
>
,
remove_reference_t
<
CGM0N0M1N1M2M3M4N2GridDesc
>
,
remove_reference_t
<
CBlockClusterAdaptor
>>
;
remove_reference_t
<
CBlockClusterAdaptor
>>
;
/* #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = launch_and_time_kernel(kernel,
float
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat,
nrepeat
,
dim3(grid_size),
dim3
(
grid_size
),
dim3(BlockSize),
dim3
(
BlockSize
),
0,
0
,
p_a_grid,
p_a_grid
,
p_b_grid,
p_b_grid
,
p_c_grid,
p_c_grid
,
a_k0_m_k1_grid_desc,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc,
b_k0_n_k1_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_block_cluster_adaptor);
c_block_cluster_adaptor
);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
DeviceMem
a_g_k0_m_k1_grid_desc_dev_buf
(
sizeof
(
AGK0MK1GridDesc
));
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
DeviceMem
b_g_k0_n_k1_grid_desc_dev_buf
(
sizeof
(
BGK0NK1GridDesc
));
DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
DeviceMem
c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
(
sizeof
(
CGM0N0M1N1M2M3M4N2GridDesc
));
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
DeviceMem
c_block_cluster_adaptor_dev_buf
(
sizeof
(
CBlockClusterAdaptor
));
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
a_g_k0_m_k1_grid_desc_dev_buf
.
ToDevice
(
&
a_g_k0_m_k1_grid_desc
);
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
b_g_k0_n_k1_grid_desc_dev_buf
.
ToDevice
(
&
b_g_k0_n_k1_grid_desc
);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
ToDevice
(
&
c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
c_block_cluster_adaptor_dev_buf
.
ToDevice
(
&
c_block_cluster_adaptor
);
float ave_time = launch_and_time_kernel(
float
ave_time
=
launch_and_time_kernel
(
kernel,
kernel
,
nrepeat,
nrepeat
,
dim3(grid_size),
dim3
(
grid_size
),
dim3(BlockSize),
dim3
(
BlockSize
),
0,
0
,
p_a_grid,
p_a_grid
,
p_b_grid,
p_b_grid
,
p_c_grid,
p_c_grid
,
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space
(
a_g_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space
(
b_g_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space(
cast_pointer_to_constant_address_space
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
cast_pointer_to_constant_address_space
(
c_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
#endif
#endif
return ave_time;*/
return
ave_time
;
return
0.0
;
}
}
#endif
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment