Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2a3a2f95
Commit
2a3a2f95
authored
May 05, 2023
by
Po-Yen, Chen
Browse files
Move descriptor creation logic into GridwiseGemm
parent
0bec80e5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
167 additions
and
46 deletions
+167
-46
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
...e/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
+34
-27
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+133
-19
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
View file @
2a3a2f95
...
@@ -184,10 +184,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -184,10 +184,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
}
}
}
}
using
AGridDesc_K0_M_K1
=
decltype
(
MakeAGridDescriptor_K0_M_K1
(
1
,
1
,
1
));
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
BlockSize
,
...
@@ -195,12 +191,13 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -195,12 +191,13 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
A
GridDesc_K0_M_K1
,
A
Layout
,
B
GridDesc_K0_N_K1
,
B
Layout
,
C
GridDesc_M_N
,
C
Layout
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
GemmSpec
,
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
K0PerBlock
,
K0PerBlock
,
...
@@ -232,6 +229,10 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -232,6 +229,10 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
LoopSched
,
LoopSched
,
PipelineVer
>
;
PipelineVer
>
;
using
AGridDesc_K0_M_K1
=
decltype
(
GridwiseGemm
::
MakeAGridDescriptor_K0_M_K1
(
1
,
1
,
1
,
1
));
using
BGridDesc_K0_N_K1
=
decltype
(
GridwiseGemm
::
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
1
,
1
,
1
,
1
,
1
));
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -241,22 +242,28 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -241,22 +242,28 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t
M_
,
index_t
M_
,
index_t
N_
,
index_t
N_
,
index_t
K_
,
index_t
K_
,
index_t
StrideA
,
index_t
StrideA
_
,
index_t
StrideB
,
index_t
StrideB
_
,
index_t
StrideC
)
index_t
StrideC
_
)
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
M
{
M_
},
M
{
M_
},
N
{
N_
},
N
{
N_
},
K
{
K_
},
K
{
K_
},
a_grid_desc_k0_m_k1_
{},
StrideA
{
StrideA_
},
b_grid_desc_k0_n_k1_
{},
StrideB
{
StrideB_
},
c_grid_desc_m_n_
{}
StrideC
{
StrideC_
},
MPadded
{
GridwiseGemm
::
CalculateMPadded
(
M_
)},
NPadded
{
GridwiseGemm
::
CalculateNPadded
(
N_
)},
a_grid_desc_k0_m_k1
{},
b_grid_desc_k0_n_k1
{},
c_grid_desc_m_n
{}
{
{
a_grid_desc_k0_m_k1_
=
DeviceGemmXdl
::
MakeAGridDescriptor_K0_M_K1
(
M_
,
K_
,
StrideA
);
a_grid_desc_k0_m_k1
=
GridwiseGemm
::
MakeAGridDescriptor_K0_M_K1
(
M
,
MPadded
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
=
DeviceGemmXdl
::
MakeBGridDescriptor_K0_N_K1
(
K_
,
N_
,
StrideB
);
b_grid_desc_k0_n_k1
=
GridwiseGemm
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
NPadded
,
StrideB
);
c_grid_desc_m_n_
=
DeviceGemmXdl
::
MakeCGridDescriptor_M_N
(
M_
,
N_
,
StrideC
);
c_grid_desc_m_n
=
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
M
,
MPadded
,
N
,
NPadded
,
StrideC
);
}
}
// private:
// private:
...
@@ -266,9 +273,14 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -266,9 +273,14 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t
M
;
index_t
M
;
index_t
N
;
index_t
N
;
index_t
K
;
index_t
K
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
index_t
StrideA
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
index_t
StrideB
;
CGridDesc_M_N
c_grid_desc_m_n_
;
index_t
StrideC
;
index_t
MPadded
;
index_t
NPadded
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
;
CGridDesc_M_N
c_grid_desc_m_n
;
};
};
// Invoker
// Invoker
...
@@ -293,8 +305,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -293,8 +305,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
}
}
#endif
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
))
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
...
@@ -303,12 +314,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -303,12 +314,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
index_t
gdx
,
gdy
,
gdz
;
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
M
,
arg
.
N
);
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
M
,
arg
.
N
);
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
float
ave_time
=
0
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
arg
.
K
))
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
Argument
,
true
>
;
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
Argument
,
true
>
;
...
@@ -382,8 +390,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
...
@@ -382,8 +390,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return
false
;
return
false
;
}
}
return
GridwiseGemm
::
CheckValidity
(
return
GridwiseGemm
::
CheckValidity
(
arg
);
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
);
}
}
// polymorphic
// polymorphic
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
2a3a2f95
...
@@ -44,12 +44,13 @@ template <index_t BlockSize,
...
@@ -44,12 +44,13 @@ template <index_t BlockSize,
typename
FloatAcc
,
typename
FloatAcc
,
typename
FloatC_
,
typename
FloatC_
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
A
GridDesc_K0_M_K1
,
typename
A
Layout
,
typename
B
GridDesc_K0_N_K1
,
typename
B
Layout
,
typename
C
GridDesc_M_N
,
typename
C
Layout
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
,
index_t
MPerBlock
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
NPerBlock
,
index_t
K0PerBlock
,
index_t
K0PerBlock
,
...
@@ -120,6 +121,111 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -120,6 +121,111 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
}
#undef INTEGER_DIVIDE_CEIL
#undef INTEGER_DIVIDE_CEIL
__host__
__device__
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
StrideA
)
{
const
index_t
K0
=
K
/
K1
;
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
}
}();
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Value
)),
make_right_pad_transform
(
M
,
M
-
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Value
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
__host__
__device__
static
auto
MakeBGridDescriptor_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
NPad
,
index_t
StrideB
)
{
const
index_t
K0
=
K
/
K1
;
const
auto
b_grid_desc_k_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
StrideB
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
I1
,
StrideB
));
}
}();
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Value
)),
make_right_pad_transform
(
N
,
N
-
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Value
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
__host__
__device__
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
MPad
,
index_t
N
,
index_t
NPad
,
index_t
StrideC
)
{
const
auto
c_grid_desc_m_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
}
}();
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_right_pad_transform
(
M
,
M
-
MPad
),
make_right_pad_transform
(
N
,
N
-
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_pass_through_transform
(
M
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
...
@@ -196,10 +302,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -196,10 +302,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__
__device__
static
constexpr
bool
template
<
typename
Argument
>
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
"wrong! K1 need to be known at compile-time"
);
...
@@ -208,13 +312,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -208,13 +312,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
(
NPerBlock
%
(
NXdlPerWave
*
NPerXDL
))
==
0
,
(
NPerBlock
%
(
NXdlPerWave
*
NPerXDL
))
==
0
,
"Invalid tuning param!"
);
"Invalid tuning param!"
);
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
M
=
karg
.
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N
=
karg
.
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
K0
=
karg
.
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
if
(
!
(
M
==
c_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
c_grid_desc_m_n
.
GetLength
(
I1
)
&&
if
(
!
(
M
==
karg
.
c_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
karg
.
c_grid_desc_m_n
.
GetLength
(
I1
)
&&
K0
==
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
&&
K1
==
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
&&
K0
==
karg
.
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
&&
K1
==
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)))
K1
==
karg
.
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
&&
K1
==
karg
.
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)))
return
false
;
return
false
;
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
))
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
))
...
@@ -239,8 +344,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -239,8 +344,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
}
template
<
typename
CGridDesc
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc
_M_N
&
c_grid_desc_m_n
)
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc
&
c_grid_desc_m_n
)
{
{
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
max_lds_align
=
K1
;
...
@@ -291,8 +397,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -291,8 +397,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// return block_id to C matrix tile idx (m0, n0) mapping
// return block_id to C matrix tile idx (m0, n0) mapping
using
Block2CTileMap
=
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
>
;
using
Block2CTileMap
=
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
>
;
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
template
<
bool
HasMainKBlockLoop
,
typename
Argument
>
template
<
bool
HasMainKBlockLoop
,
typename
Argument
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
...
@@ -301,9 +405,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -301,9 +405,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
Argument
&
karg
)
const
Argument
&
karg
)
{
{
const
auto
a_grid_desc_k0_m_k1
=
karg
.
a_grid_desc_k0_m_k1_
;
#define CREATE_DESC_ON_HOST 1
const
auto
b_grid_desc_k0_n_k1
=
karg
.
b_grid_desc_k0_n_k1_
;
#if CREATE_DESC_ON_HOST
const
auto
c_grid_desc_m_n
=
karg
.
c_grid_desc_m_n_
;
const
auto
a_grid_desc_k0_m_k1
=
karg
.
a_grid_desc_k0_m_k1
;
const
auto
b_grid_desc_k0_n_k1
=
karg
.
b_grid_desc_k0_n_k1
;
const
auto
c_grid_desc_m_n
=
karg
.
c_grid_desc_m_n
;
#else
const
auto
a_grid_desc_k0_m_k1
=
MakeAGridDescriptor_K0_M_K1
(
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
StrideA
);
const
auto
b_grid_desc_k0_n_k1
=
MakeBGridDescriptor_K0_N_K1
(
karg
.
K
,
karg
.
N
,
karg
.
NPadded
,
karg
.
StrideB
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
MPadded
,
karg
.
N
,
karg
.
NPadded
,
karg
.
StrideC
);
#endif
const
auto
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
const
auto
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n
);
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment