Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
bbe5c0c7
Commit
bbe5c0c7
authored
Mar 05, 2022
by
Jing Zhang
Browse files
2 gemm test
parent
6cbb0a13
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
182 additions
and
95 deletions
+182
-95
composable_kernel/include/config.hpp
composable_kernel/include/config.hpp
+1
-0
composable_kernel/include/tensor_operation/gridwise_grouped_gemm_xdlops_v2r3.hpp
...de/tensor_operation/gridwise_grouped_gemm_xdlops_v2r3.hpp
+72
-16
device_operation/include/device_grouped_gemm_xdl.hpp
device_operation/include/device_grouped_gemm_xdl.hpp
+104
-74
example/12_grouped_gemm_xdl/grouped_gemm_xdl.cpp
example/12_grouped_gemm_xdl/grouped_gemm_xdl.cpp
+5
-5
No files found.
composable_kernel/include/config.hpp
View file @
bbe5c0c7
...
...
@@ -176,6 +176,7 @@ struct gemm_desc
ck
::
index_t
M
,
N
,
K
;
ck
::
index_t
StrideA
,
StrideB
,
StrideC
;
ck
::
index_t
OffsetA
,
OffsetB
,
OffsetC
;
ck
::
index_t
BlockStart
,
BlockSize
;
};
}
// namespace ck
...
...
composable_kernel/include/tensor_operation/gridwise_grouped_gemm_xdlops_v2r3.hpp
View file @
bbe5c0c7
...
...
@@ -18,11 +18,13 @@ template <typename GridwiseGemm,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
GemmDesc
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Block2CTileMap
,
bool
HasMainK0BlockLoop
>
bool
HasMainK0BlockLoop
,
index_t
MaxGroupCount
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
...
@@ -34,6 +36,8 @@ __global__ void
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
GemmDesc
gemm_shapes
,
const
index_t
group_count
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
...
...
@@ -43,24 +47,67 @@ __global__ void
const
index_t
block_id
=
get_block_1d_id
();
const
index_t
group_id
=
0
;
index_t
group_id
=
0
;
index_t
block_id_grp
=
0
;
index_t
a_offset_grp
=
0
;
index_t
b_offset_grp
=
0
;
index_t
c_offset_grp
=
0
;
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
if
(
i
<
group_count
)
{
if
(
block_id
>=
gemm_shapes
[
i
].
BlockStart
&&
block_id
<
(
gemm_shapes
[
i
].
BlockStart
+
gemm_shapes
[
i
].
BlockSize
))
{
group_id
=
i
;
block_id_grp
=
block_id
-
gemm_shapes
[
i
].
BlockStart
;
a_offset_grp
=
gemm_shapes
[
i
].
OffsetA
;
b_offset_grp
=
gemm_shapes
[
i
].
OffsetB
;
c_offset_grp
=
gemm_shapes
[
i
].
OffsetC
;
// if(get_thread_local_1d_id() == 0)
// printf("%d %d %d %d %d %d\n",
// block_id,
// group_id,
// block_id_grp,
// a_offset_grp,
// b_offset_grp,
// c_offset_grp);
}
}
});
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
if
(
group_id
==
0
)
GridwiseGemm
::
template
Run
<
HasMainK0BlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared
,
a_grid_desc_k0_m_k1
[
Number
<
0
>
{}],
b_grid_desc_k0_n_k1
[
Number
<
0
>
{}],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
[
Number
<
0
>
{}],
a_element_op
,
b_element_op
,
c_element_op
,
block_2_ctile_map
[
Number
<
0
>
{}],
block_id
);
GridwiseGemm
::
template
Run
<
HasMainK0BlockLoop
>(
p_a_grid
+
a_offset_grp
,
p_b_grid
+
b_offset_grp
,
p_c_grid
+
c_offset_grp
,
p_shared
,
a_grid_desc_k0_m_k1
[
I0
],
b_grid_desc_k0_n_k1
[
I0
],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
[
I0
],
a_element_op
,
b_element_op
,
c_element_op
,
block_2_ctile_map
[
I0
],
block_id_grp
,
group_id
);
else
if
(
group_id
==
1
)
GridwiseGemm
::
template
Run
<
HasMainK0BlockLoop
>(
p_a_grid
+
a_offset_grp
,
p_b_grid
+
b_offset_grp
,
p_c_grid
+
c_offset_grp
,
p_shared
,
a_grid_desc_k0_m_k1
[
I1
],
b_grid_desc_k0_n_k1
[
I1
],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
[
I1
],
a_element_op
,
b_element_op
,
c_element_op
,
block_2_ctile_map
[
I1
],
block_id_grp
,
group_id
);
}
template
<
index_t
BlockSize
,
...
...
@@ -360,7 +407,8 @@ struct GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
index_t
block_id
)
const
index_t
block_id
,
const
index_t
group_id
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_a_grid
,
a_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
...
...
@@ -382,6 +430,14 @@ struct GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
NPerBlock
);
// if(get_thread_local_1d_id() == 0)
//{
// printf("m: %d n: %d k: %d\n", a_grid_desc_k0_m_k1.GetLength(I1),
// b_grid_desc_k0_n_k1.GetLength(I1), a_grid_desc_k0_m_k1.GetLength(I0));
// printf("block_work_idx: %d %d %d %d\n", group_id, block_id, block_work_idx[I0],
// block_work_idx[I1]);
//}
// lds max alignment
constexpr
auto
max_lds_align
=
K1
;
...
...
device_operation/include/device_grouped_gemm_xdl.hpp
View file @
bbe5c0c7
...
...
@@ -53,8 +53,8 @@ template <typename ADataType,
bool
BBlockLdsAddExtraN
,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
ck
::
index_t
NumPrefetch
=
1
,
ck
::
index_t
GroupCount
=
1
>
ck
::
index_t
NumPrefetch
=
1
,
ck
::
index_t
Max
GroupCount
=
5
>
struct
DeviceGroupedGemmXdl
:
public
DeviceGroupedGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
...
...
@@ -238,55 +238,62 @@ struct DeviceGroupedGemmXdl
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
gemm_shapes_
{
gemm_shapes
},
M01_
{
M01
},
N01_
{
N01
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
{
const
index_t
i
=
0
;
const
index_t
M
=
gemm_shapes
[
Number
<
0
>
{}].
M
;
const
index_t
N
=
gemm_shapes
[
Number
<
0
>
{}].
N
;
const
index_t
K
=
gemm_shapes
[
Number
<
0
>
{}].
K
;
const
index_t
StrideA
=
gemm_shapes
[
Number
<
0
>
{}].
StrideA
;
const
index_t
StrideB
=
gemm_shapes
[
Number
<
0
>
{}].
StrideB
;
const
index_t
StrideC
=
gemm_shapes
[
Number
<
0
>
{}].
StrideC
;
a_grid_desc_k0_m_k1_
(
Number
<
0
>
{})
=
DeviceGroupedGemmXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
(
Number
<
0
>
{})
=
DeviceGroupedGemmXdl
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
c_grid_desc_m_n_
(
Number
<
0
>
{})
=
DeviceGroupedGemmXdl
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
[
Number
<
0
>
{}],
b_grid_desc_k0_n_k1_
[
Number
<
0
>
{}],
c_grid_desc_m_n_
[
Number
<
0
>
{}],
M01_
,
N01_
))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
(
Number
<
0
>
{})
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
[
Number
<
0
>
{}]);
block_2_ctile_map_
(
Number
<
0
>
{})
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
[
Number
<
0
>
{}],
M01
,
N01
);
}
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
if
(
i
<
gemm_shapes_
.
size
())
{
const
index_t
M
=
gemm_shapes_
[
i
].
M
;
const
index_t
N
=
gemm_shapes_
[
i
].
N
;
const
index_t
K
=
gemm_shapes_
[
i
].
K
;
const
index_t
StrideA
=
gemm_shapes_
[
i
].
StrideA
;
const
index_t
StrideB
=
gemm_shapes_
[
i
].
StrideB
;
const
index_t
StrideC
=
gemm_shapes_
[
i
].
StrideC
;
a_grid_desc_k0_m_k1_
(
i
)
=
DeviceGroupedGemmXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
(
i
)
=
DeviceGroupedGemmXdl
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
c_grid_desc_m_n_
(
i
)
=
DeviceGroupedGemmXdl
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
[
i
],
b_grid_desc_k0_n_k1_
[
i
],
c_grid_desc_m_n_
[
i
],
M01_
,
N01_
))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
(
i
)
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
[
i
]);
block_2_ctile_map_
(
i
)
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
[
i
],
M01
,
N01
);
}
}
});
}
// private:
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
StaticallyIndexedArray
<
AGridDesc_K0_M_K1
,
GroupCount
>
a_grid_desc_k0_m_k1_
;
StaticallyIndexedArray
<
BGridDesc_K0_N_K1
,
GroupCount
>
b_grid_desc_k0_n_k1_
;
StaticallyIndexedArray
<
CGridDesc_M_N
,
GroupCount
>
c_grid_desc_m_n_
;
StaticallyIndexedArray
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
GroupCount
>
StaticallyIndexedArray
<
AGridDesc_K0_M_K1
,
MaxGroupCount
>
a_grid_desc_k0_m_k1_
;
StaticallyIndexedArray
<
BGridDesc_K0_N_K1
,
MaxGroupCount
>
b_grid_desc_k0_n_k1_
;
StaticallyIndexedArray
<
CGridDesc_M_N
,
MaxGroupCount
>
c_grid_desc_m_n_
;
StaticallyIndexedArray
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
MaxGroupCount
>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
StaticallyIndexedArray
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
GroupCount
>
StaticallyIndexedArray
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
Max
GroupCount
>
block_2_ctile_map_
;
std
::
vector
<
gemm_desc
>
gemm_shapes_
;
index_t
M01_
;
index_t
N01_
;
AElementwiseOperation
a_element_op_
;
...
...
@@ -301,33 +308,49 @@ struct DeviceGroupedGemmXdl
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
std
::
cout
<<
"Gemm "
<<
i
<<
":"
<<
std
::
endl
;
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
[
Number
<
0
>
{}].
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
[
Number
<
0
>
{}].
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
[
Number
<
0
>
{}].
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_k0_n_k1_{"
<<
arg
.
b_grid_desc_k0_n_k1_
[
Number
<
0
>
{}].
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
[
Number
<
0
>
{}].
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
[
Number
<
0
>
{}].
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
[
Number
<
0
>
{}].
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
[
Number
<
0
>
{}].
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
[
Number
<
0
>
{}],
arg
.
b_grid_desc_k0_n_k1_
[
Number
<
0
>
{}],
arg
.
c_grid_desc_m_n_
[
Number
<
0
>
{}],
arg
.
M01_
,
arg
.
N01_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
}
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
[
Number
<
0
>
{}]);
StaticallyIndexedArray
<
gemm_desc
,
MaxGroupCount
>
gemm_shapes
;
index_t
grid_size
=
0
;
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
if
(
i
<
arg
.
gemm_shapes_
.
size
())
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
[
i
].
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
[
i
].
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_k0_n_k1_{"
<<
arg
.
b_grid_desc_k0_n_k1_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
[
i
].
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
[
i
].
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
[
i
].
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
[
i
],
arg
.
b_grid_desc_k0_n_k1_
[
i
],
arg
.
c_grid_desc_m_n_
[
i
],
arg
.
M01_
,
arg
.
N01_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
}
const
index_t
grid_size_grp
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
[
i
]);
gemm_shapes
(
i
)
=
arg
.
gemm_shapes_
[
i
];
gemm_shapes
(
i
).
BlockStart
=
grid_size
;
gemm_shapes
(
i
).
BlockSize
=
grid_size_grp
;
grid_size
+=
grid_size_grp
;
std
::
cout
<<
"group_id "
<<
i
<<
" BlockStart "
<<
gemm_shapes
(
i
).
BlockStart
<<
" BlockSize "
<<
gemm_shapes
(
i
).
BlockSize
<<
std
::
endl
;
}
});
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
[
Number
<
0
>
{}].
GetLength
(
I0
);
...
...
@@ -344,20 +367,22 @@ struct DeviceGroupedGemmXdl
CDataType
,
remove_reference_t
<
StaticallyIndexedArray
<
DeviceGroupedGemmXdl
::
AGridDesc_K0_M_K1
,
GroupCount
>>
,
Max
GroupCount
>>
,
remove_reference_t
<
StaticallyIndexedArray
<
DeviceGroupedGemmXdl
::
BGridDesc_K0_N_K1
,
GroupCount
>>
,
Max
GroupCount
>>
,
remove_reference_t
<
StaticallyIndexedArray
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
GroupCount
>>
,
MaxGroupCount
>>
,
remove_reference_t
<
StaticallyIndexedArray
<
gemm_desc
,
MaxGroupCount
>>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
StaticallyIndexedArray
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
GroupCount
>>
,
true
>
;
MaxGroupCount
>>
,
true
,
MaxGroupCount
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
...
...
@@ -370,6 +395,8 @@ struct DeviceGroupedGemmXdl
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
gemm_shapes
,
arg
.
gemm_shapes_
.
size
(),
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
...
...
@@ -383,20 +410,22 @@ struct DeviceGroupedGemmXdl
CDataType
,
remove_reference_t
<
StaticallyIndexedArray
<
DeviceGroupedGemmXdl
::
AGridDesc_K0_M_K1
,
GroupCount
>>
,
Max
GroupCount
>>
,
remove_reference_t
<
StaticallyIndexedArray
<
DeviceGroupedGemmXdl
::
BGridDesc_K0_N_K1
,
GroupCount
>>
,
Max
GroupCount
>>
,
remove_reference_t
<
StaticallyIndexedArray
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
GroupCount
>>
,
MaxGroupCount
>>
,
remove_reference_t
<
StaticallyIndexedArray
<
gemm_desc
,
MaxGroupCount
>>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
StaticallyIndexedArray
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
GroupCount
>>
,
false
>
;
MaxGroupCount
>>
,
false
,
MaxGroupCount
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
...
...
@@ -409,6 +438,8 @@ struct DeviceGroupedGemmXdl
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
gemm_shapes
,
arg
.
gemm_shapes_
.
size
(),
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
...
...
@@ -434,7 +465,6 @@ struct DeviceGroupedGemmXdl
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
const
index_t
i
=
0
;
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
[
Number
<
0
>
{}],
arg
.
b_grid_desc_k0_n_k1_
[
Number
<
0
>
{}],
arg
.
c_grid_desc_m_n_
[
Number
<
0
>
{}],
...
...
example/12_grouped_gemm_xdl/grouped_gemm_xdl.cpp
View file @
bbe5c0c7
...
...
@@ -76,7 +76,7 @@ int main(int argc, char* argv[])
exit
(
0
);
}
int
group_count
=
1
;
int
group_count
=
2
;
// GEMM shape
std
::
vector
<
ck
::
gemm_desc
>
gemm_shapes
;
...
...
@@ -85,11 +85,11 @@ int main(int argc, char* argv[])
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
256
;
int
N
=
512
;
int
K
=
1024
;
int
M
=
256
*
(
i
+
1
)
;
int
N
=
512
*
(
i
+
1
)
;
int
K
=
1024
*
(
i
+
1
)
;
gemm_shapes
.
push_back
({
M
,
N
,
K
,
K
,
K
,
N
,
A_size
,
B_size
,
C_size
});
gemm_shapes
.
push_back
({
M
,
N
,
K
,
K
,
K
,
N
,
A_size
,
B_size
,
C_size
,
0
,
0
});
A_size
+=
M
*
K
;
B_size
+=
N
*
K
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment