Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
62fdce6d
"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "138fac703a8359460a59d20082114195759722b8"
Commit
62fdce6d
authored
Sep 07, 2021
by
ltqin
Browse files
remove a b matrix k0mk1 desc in grid
parent
b7c1259f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
54 deletions
+28
-54
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
+18
-34
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nkhw.hpp
...ard_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nkhw.hpp
+10
-10
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
+0
-10
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
View file @
62fdce6d
...
@@ -16,8 +16,6 @@ namespace ck {
...
@@ -16,8 +16,6 @@ namespace ck {
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatAB
,
typename
FloatC
,
typename
FloatC
,
typename
AK0MK1GridDesc
,
typename
BK0NK1GridDesc
,
typename
ABK0MK1GridDesc
,
typename
ABK0MK1GridDesc
,
typename
BBK0NK1GridDesc
,
typename
BBK0NK1GridDesc
,
typename
CM0N0M1N1M2M3M4N2GridDesc
,
typename
CM0N0M1N1M2M3M4N2GridDesc
,
...
@@ -29,10 +27,8 @@ __global__ void
...
@@ -29,10 +27,8 @@ __global__ void
kernel_gemm_xdlops_v2r4
(
const
FloatAB
*
__restrict__
p_a_grid
,
kernel_gemm_xdlops_v2r4
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AK0MK1GridDesc
a_k0_m_k1_grid_desc
,
const
void
ABK0MK1GridDesc
a_b_k0_m_k1_grid_desc
,
const
BK0NK1GridDesc
b_k0_n_k1_grid_desc
,
const
void
BBK0NK1GridDesc
b_b_k0_n_k1_grid_desc
,
const
void
CONSTANT
*
a_b_k0_m_k1_grid_desc
,
const
void
CONSTANT
*
b_b_k0_n_k1_grid_desc
,
const
CM0N0M1N1M2M3M4N2GridDesc
c_m0_m1_m2_n_grid_desc
,
const
CM0N0M1N1M2M3M4N2GridDesc
c_m0_m1_m2_n_grid_desc
,
const
CBlockClusterAdaptor
c_block_cluster_adaptor
)
const
CBlockClusterAdaptor
c_block_cluster_adaptor
)
{
{
...
@@ -45,8 +41,6 @@ __global__ void
...
@@ -45,8 +41,6 @@ __global__ void
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
p_shared_block
,
p_shared_block
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
a_b_k0_m_k1_grid_desc
,
a_b_k0_m_k1_grid_desc
,
b_b_k0_n_k1_grid_desc
,
b_b_k0_n_k1_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
...
@@ -56,8 +50,6 @@ __global__ void
...
@@ -56,8 +50,6 @@ __global__ void
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatAB
,
typename
FloatC
,
typename
FloatC
,
typename
AK0MK1GridDesc
,
typename
BK0NK1GridDesc
,
typename
ABK0MK1GridDesc
,
typename
ABK0MK1GridDesc
,
typename
BBK0NK1GridDesc
,
typename
BBK0NK1GridDesc
,
typename
CM0N0M1N1M2M3M4N2GridDesc
,
typename
CM0N0M1N1M2M3M4N2GridDesc
,
...
@@ -69,8 +61,6 @@ __global__ void
...
@@ -69,8 +61,6 @@ __global__ void
kernel_gemm_xdlops_v2r4
(
const
FloatAB
*
__restrict__
p_a_grid
,
kernel_gemm_xdlops_v2r4
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
void
CONSTANT
*
p_a_k0_m_k1_grid_desc
,
const
void
CONSTANT
*
p_b_k0_n_k1_grid_desc
,
const
void
CONSTANT
*
p_a_b_k0_m_k1_grid_desc
,
const
void
CONSTANT
*
p_a_b_k0_m_k1_grid_desc
,
const
void
CONSTANT
*
p_b_b_k0_n_k1_grid_desc
,
const
void
CONSTANT
*
p_b_b_k0_n_k1_grid_desc
,
const
void
CONSTANT
*
p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
const
void
CONSTANT
*
p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
...
@@ -79,10 +69,6 @@ __global__ void
...
@@ -79,10 +69,6 @@ __global__ void
constexpr
index_t
shared_block_size
=
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
const
auto
a_k0_m_k1_grid_desc
=
*
reinterpret_cast
<
const
AK0MK1GridDesc
*>
(
cast_pointer_to_generic_address_space
(
p_a_k0_m_k1_grid_desc
));
const
auto
b_k0_n_k1_grid_desc
=
*
reinterpret_cast
<
const
BK0NK1GridDesc
*>
(
cast_pointer_to_generic_address_space
(
p_b_k0_n_k1_grid_desc
));
const
auto
a_b_k0_m_k1_grid_desc
=
*
reinterpret_cast
<
const
ABK0MK1GridDesc
*>
(
const
auto
a_b_k0_m_k1_grid_desc
=
*
reinterpret_cast
<
const
ABK0MK1GridDesc
*>
(
cast_pointer_to_generic_address_space
(
p_a_b_k0_m_k1_grid_desc
));
cast_pointer_to_generic_address_space
(
p_a_b_k0_m_k1_grid_desc
));
const
auto
b_b_k0_n_k1_grid_desc
=
*
reinterpret_cast
<
const
BBK0NK1GridDesc
*>
(
const
auto
b_b_k0_n_k1_grid_desc
=
*
reinterpret_cast
<
const
BBK0NK1GridDesc
*>
(
...
@@ -99,8 +85,6 @@ __global__ void
...
@@ -99,8 +85,6 @@ __global__ void
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
p_shared_block
,
p_shared_block
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
a_b_k0_m_k1_grid_desc
,
a_b_k0_m_k1_grid_desc
,
b_b_k0_n_k1_grid_desc
,
b_b_k0_n_k1_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
...
@@ -348,8 +332,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -348,8 +332,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatAB
*
__restrict__
p_shared_block
,
FloatAB
*
__restrict__
p_shared_block
,
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
,
const
ABK0MK1GridDesc
&
a_b_k0_m_k1_grid_desc
,
const
ABK0MK1GridDesc
&
a_b_k0_m_k1_grid_desc
,
const
BBK0NK1GridDesc
&
b_b_k0_n_k1_grid_desc
,
const
BBK0NK1GridDesc
&
b_b_k0_n_k1_grid_desc
,
const
CM0N0M1N1M2M3M4N2GridDesc
&
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
const
CM0N0M1N1M2M3M4N2GridDesc
&
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
...
@@ -362,13 +344,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -362,13 +344,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_grid
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
.
GetElementSpaceSize
());
p_c_grid
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
.
GetElementSpaceSize
());
const
auto
K0
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I1
);
const
auto
K0
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I1
);
const
auto
M
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I2
);
const
auto
M
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I2
);
const
auto
N
=
b_b_k0_n_k1_grid_desc
.
GetLength
(
I2
);
const
auto
N
=
b_b_k0_n_k1_grid_desc
.
GetLength
(
I2
);
const
auto
b_grid_size
=
CalculateGridSize
(
M
,
N
);
const
auto
b_grid_size
=
CalculateGridSize
(
M
,
N
);
const
auto
k_batch_id
=
get_block_1d_id
()
/
b_grid_size
;
const
auto
k_batch_id
=
get_block_1d_id
()
/
b_grid_size
;
const
auto
block_id_in_batch
=
get_block_1d_id
()
%
b_grid_size
;
const
auto
block_id_in_batch
=
get_block_1d_id
()
%
b_grid_size
;
if
(
get_block_1d_id
()
==
2000
)
if
(
get_block_1d_id
()
==
2000
00
)
printf
(
"grid size: %d, k0: %d, blockid: %d, threadid %d, Batch: %d block_id: %d
\n
"
,
printf
(
"grid size: %d, k0: %d, blockid: %d, threadid %d, Batch: %d block_id: %d
\n
"
,
b_grid_size
,
b_grid_size
,
K0
,
K0
,
...
@@ -426,10 +408,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -426,10 +408,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
1
,
1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
a_b_k0_m_k1_grid_desc
,
true
>
(
make_multi_index
(
k_batch_id
,
0
,
m_block_data_idx_on_grid
,
0
),
a_b_k0_m_k1_grid_desc
,
a_b_k0_m_k1_block_desc
,
make_multi_index
(
k_batch_id
,
0
,
m_block_data_idx_on_grid
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
));
a_b_k0_m_k1_block_desc
,
make_multi_index
(
0
,
0
,
0
,
0
));
// B matrix blockwise copy
// B matrix blockwise copy
auto
b_blockwise_copy
=
auto
b_blockwise_copy
=
...
@@ -452,10 +435,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -452,10 +435,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
1
,
1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
b_b_k0_n_k1_grid_desc
,
true
>
(
make_multi_index
(
k_batch_id
,
0
,
n_block_data_idx_on_grid
,
0
),
b_b_k0_n_k1_grid_desc
,
b_b_k0_n_k1_block_desc
,
make_multi_index
(
k_batch_id
,
0
,
n_block_data_idx_on_grid
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
));
b_b_k0_n_k1_block_desc
,
make_multi_index
(
0
,
0
,
0
,
0
));
// GEMM definition
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// c_mtx += transpose(a_mtx) * b_mtx
...
...
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nkhw.hpp
View file @
62fdce6d
...
@@ -76,7 +76,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
...
@@ -76,7 +76,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
KBatch
=
2
;
constexpr
index_t
KBatch
=
64
;
#elif 1
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -132,15 +132,15 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
...
@@ -132,15 +132,15 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: GemmK1
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_step_hacks
=
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmB
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmB
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 3+: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 3+: GemmK1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: GemmB
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: GemmB
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: GemmK1
constexpr
auto
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
constexpr
auto
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
...
...
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
View file @
62fdce6d
...
@@ -156,8 +156,6 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -156,8 +156,6 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
const
auto
kernel
=
kernel_gemm_xdlops_v2r4
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdlops_v2r4
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AK0MK1GridDesc
>
,
remove_reference_t
<
BK0NK1GridDesc
>
,
remove_reference_t
<
ABK0MK1GridDesc
>
,
remove_reference_t
<
ABK0MK1GridDesc
>
,
remove_reference_t
<
BBK0NK1GridDesc
>
,
remove_reference_t
<
BBK0NK1GridDesc
>
,
remove_reference_t
<
CM0N0M1N1M2M3M4N2GridDesc
>
,
remove_reference_t
<
CM0N0M1N1M2M3M4N2GridDesc
>
,
...
@@ -172,23 +170,17 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -172,23 +170,17 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
a_b_k0_m_k1_grid_desc
,
a_b_k0_m_k1_grid_desc
,
b_b_k0_n_k1_grid_desc
,
b_b_k0_n_k1_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_block_cluster_adaptor
);
c_block_cluster_adaptor
);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem
a_k0_m_k1_grid_desc_dev_buf
(
sizeof
(
AK0MK1GridDesc
));
DeviceMem
b_k0_n_k1_grid_desc_dev_buf
(
sizeof
(
BK0NK1GridDesc
));
DeviceMem
a_b_k0_m_k1_grid_desc_dev_buf
(
sizeof
(
ABK0MK1GridDesc
));
DeviceMem
a_b_k0_m_k1_grid_desc_dev_buf
(
sizeof
(
ABK0MK1GridDesc
));
DeviceMem
b_b_k0_n_k1_grid_desc_dev_buf
(
sizeof
(
BBK0NK1GridDesc
));
DeviceMem
b_b_k0_n_k1_grid_desc_dev_buf
(
sizeof
(
BBK0NK1GridDesc
));
DeviceMem
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
(
sizeof
(
CM0N0M1N1M2M3M4N2GridDesc
));
DeviceMem
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
(
sizeof
(
CM0N0M1N1M2M3M4N2GridDesc
));
DeviceMem
c_block_cluster_adaptor_dev_buf
(
sizeof
(
CBlockClusterAdaptor
));
DeviceMem
c_block_cluster_adaptor_dev_buf
(
sizeof
(
CBlockClusterAdaptor
));
a_k0_m_k1_grid_desc_dev_buf
.
ToDevice
(
&
a_k0_m_k1_grid_desc
);
b_k0_n_k1_grid_desc_dev_buf
.
ToDevice
(
&
b_k0_n_k1_grid_desc
);
a_b_k0_m_k1_grid_desc_dev_buf
.
ToDevice
(
&
a_b_k0_m_k1_grid_desc
);
a_b_k0_m_k1_grid_desc_dev_buf
.
ToDevice
(
&
a_b_k0_m_k1_grid_desc
);
b_b_k0_n_k1_grid_desc_dev_buf
.
ToDevice
(
&
b_b_k0_n_k1_grid_desc
);
b_b_k0_n_k1_grid_desc_dev_buf
.
ToDevice
(
&
b_b_k0_n_k1_grid_desc
);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
ToDevice
(
&
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
ToDevice
(
&
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
...
@@ -203,8 +195,6 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -203,8 +195,6 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
cast_pointer_to_constant_address_space
(
a_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
a_b_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
a_b_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_b_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_b_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment