gaoqiong / composable_kernel · Commits · 41a1466a

Commit 41a1466a, authored Jul 27, 2023 by Jing Zhang

change m_loops to tile_loops

parent 36a527df

Showing 2 changed files, with 42 additions and 187 deletions:

  example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp (+3 -1)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp (+39 -186)

example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
...
@@ -55,7 +55,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
 //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+        < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
 // clang-format on

 struct ProblemSize final
...
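For reference, the instance swap above moves the active tuning from a 256 x 128 output tile with KPerBlock 32 and MXdlPerWave 4 to a 64 x 128 tile with KPerBlock 64 and MXdlPerWave 1, so each group is covered by many more, smaller M tiles. A minimal standalone sketch of the tile arithmetic, assuming made-up problem sizes (the helper mirrors ck::math::integer_divide_ceil; nothing below is the CK API itself):

    #include <cstdio>

    // Illustrative stand-in for ck::math::integer_divide_ceil.
    static int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

    // Workgroups needed for one group: one block per MPerBlock x NPerBlock
    // output tile, times the KBatch split-K factor.
    static int grid_size(int M, int N, int MPerBlock, int NPerBlock, int KBatch)
    {
        const int M0 = integer_divide_ceil(M, MPerBlock);
        const int N0 = integer_divide_ceil(N, NPerBlock);
        return M0 * N0 * KBatch;
    }

    int main()
    {
        // Made-up problem: M = 1024, N = 768, KBatch = 4.
        std::printf("256x128 tile: %d blocks\n", grid_size(1024, 768, 256, 128, 4)); // 96
        std::printf(" 64x128 tile: %d blocks\n", grid_size(1024, 768, 64, 128, 4));  // 384
        return 0;
    }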
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp

...
@@ -82,6 +82,8 @@ __global__ void
     const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch};
+    const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n);

     constexpr auto NumDTensor = DsDataType::Size();

     using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
...
@@ -94,13 +96,12 @@ __global__ void
         p_ds_grid_(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
     });

-    auto m_loops = local_b2e_tile_map.CalculateMLoops();
-    index_t m_id = 0;
+    index_t id_off = 0;

-    do
+    while((get_block_1d_id() - BlockStart + id_off) < local_grid_size)
     {
         const auto block_2_etile_map =
-            GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, m_id);
+            GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);

         GridwiseGemm::template Run<HasMainKBlockLoop, GemmSpec, ALayout, BLayout, DsLayout, ELayout>(
...
@@ -122,9 +123,8 @@ __global__ void
             KBatch,
             block_2_etile_map);

-        m_id += 1;
-    } while(m_id < m_loops);
+        id_off += grid_size_grp;
+    }
 #else
     ignore = grid_size_grp;
     ignore = gemm_descs_const;
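The kernel hunk above replaces the counted do/while over M blocks (m_loops, m_id) with a persistent loop over flat tile ids: each launched block strides through the group's tiles by grid_size_grp until local_grid_size tiles are covered. A host-side model of the new traversal with made-up sizes (a sketch of the iteration pattern only, not the device kernel):

    #include <cstdio>

    int main()
    {
        const int grid_size_grp   = 4;  // blocks launched for this group
        const int local_grid_size = 10; // tiles (M0 * N0 * KBatch) in this group
        const int BlockStart      = 0;  // first block id assigned to the group

        for(int block_id = BlockStart; block_id < BlockStart + grid_size_grp; ++block_id)
        {
            std::printf("block %d -> tiles:", block_id);
            int id_off = 0;
            while((block_id - BlockStart + id_off) < local_grid_size)
            {
                std::printf(" %d", block_id - BlockStart + id_off);
                id_off += grid_size_grp;
            }
            std::printf("\n");
        }
        return 0;
    }

Every tile id in [0, local_grid_size) is visited exactly once, even when the launched grid (sized from the average M per group) is smaller than a group's actual tile count.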
...
@@ -201,82 +201,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};

-    static constexpr auto matrix_padder =
-        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
-
-    static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
-    {
-        const auto a_grid_desc_mraw_kraw = [&]() {
-            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
-            {
-                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
-                                                    make_tuple(StrideA, I1));
-            }
-            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
-            {
-                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
-                                                    make_tuple(I1, StrideA));
-            }
-        }();
-
-        return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
-    }
-
-    static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
-    {
-        const auto b_grid_desc_nraw_kraw = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
-                                                    make_tuple(I1, StrideB));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
-                                                    make_tuple(StrideB, I1));
-            }
-        }();
-
-        return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
-    }
-
-    template <typename ELay>
-    static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
-    {
-        const auto e_grid_desc_mraw_nraw = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
-                                                    make_tuple(StrideE, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
-                                                    make_tuple(I1, StrideE));
-            }
-        }();
-
-        return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
-    }
-
-    static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
-                                         const std::array<index_t, NumDTensor>& NRaws,
-                                         const std::array<index_t, NumDTensor>& DsStride)
-    {
-        return generate_tuple(
-            [&](auto i) {
-                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-
-                return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
-            },
-            Number<NumDTensor>{});
-    }
-
-    using AGridDesc_M_K  = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
-    using BGridDesc_N_K  = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
-    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
-    using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
-
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
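The deleted helpers above built naive row-/column-major descriptors and padded them per GemmSpec; after this commit the device op relies on the GridwiseGemm-side makers instead (see MakeEGridDescriptor_M_N<ELayout, GemmSpec> further down). For orientation only, a hand-rolled model of what a naive M x K descriptor's strides mean; this is not the CK descriptor machinery:

    #include <cstdio>

    // Element offset of (m, k) under make_naive_tensor_descriptor(
    //     make_tuple(M, K), make_tuple(stride_m, stride_k)).
    static long offset(long m, long k, long stride_m, long stride_k)
    {
        return m * stride_m + k * stride_k;
    }

    int main()
    {
        // RowMajor A uses strides (StrideA, 1); ColumnMajor A uses (1, StrideA).
        std::printf("row-major    (2,3), StrideA=8 -> %ld\n", offset(2, 3, 8, 1)); // 19
        std::printf("column-major (2,3), StrideA=8 -> %ld\n", offset(2, 3, 1, 8)); // 26
        return 0;
    }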
...
@@ -321,40 +245,26 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         CDEBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;

-#if 0
-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
-    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
-    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
-#endif
-
     template <typename UnderlyingBlockToCTileMap>
     struct OffsettedBlockToCTileMapMLoops
     {
         using underlying_type = UnderlyingBlockToCTileMap;

         __host__ __device__ OffsettedBlockToCTileMapMLoops(
             UnderlyingBlockToCTileMap block_to_ctile_map,
             index_t block_start,
-            index_t mblock_id_off = 0)
+            index_t id_off = 0)
         {
             block_to_ctile_map_ = block_to_ctile_map;
             block_start_        = block_start;
-            mblock_id_off_      = mblock_id_off;
+            id_off_             = id_off;
         }

         template <typename TopIdx>
         __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
         {
             auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
-                make_multi_index(idx_top[Number<0>{}] - block_start_));
+                make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_));

             return make_tuple(idx_bot[Number<0>{}],
-                              idx_bot[Number<1>{}] + mblock_id_off_,
+                              idx_bot[Number<1>{}],
                               idx_bot[Number<2>{}]);
         }

         template <typename CTileIdx, typename CTileDim>
...
@@ -378,7 +288,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         UnderlyingBlockToCTileMap block_to_ctile_map_;
         index_t block_start_;
-        index_t mblock_id_off_;
+        index_t id_off_;
     };
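In the struct above, the map no longer carries an M-block offset (mblock_id_off_) that was added to the mapped M coordinate; it now adds a flat id_off_ to the 1D block id before delegating to the underlying map. A compact host-side model, assuming a trivial kbatch-major underlying map (purely illustrative; the real underlying map is a CK block-to-C-tile map):

    #include <cstdio>

    struct UnderlyingMap // flat tile id -> (kbatch, mblock, nblock), illustrative
    {
        int M0, N0;
        void CalculateBottomIndex(int id, int& kb, int& mb, int& nb) const
        {
            kb = id / (M0 * N0);
            mb = (id % (M0 * N0)) / N0;
            nb = id % N0;
        }
    };

    struct OffsettedMap // models the reworked OffsettedBlockToCTileMapMLoops
    {
        UnderlyingMap map_;
        int block_start_, id_off_;

        void CalculateBottomIndex(int block_1d_id, int& kb, int& mb, int& nb) const
        {
            // New behaviour: offset the flat id, then map.
            map_.CalculateBottomIndex(block_1d_id - block_start_ + id_off_, kb, mb, nb);
        }
    };

    int main()
    {
        OffsettedMap m{{4, 6}, /*block_start=*/0, /*id_off=*/24};
        int kb, mb, nb;
        m.CalculateBottomIndex(0, kb, mb, nb);
        std::printf("block 0 + id_off 24 -> kbatch %d, mblock %d, nblock %d\n", kb, mb, nb);
        return 0;
    }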
     template <index_t MPerBlock_, index_t NPerBlock_>
...
@@ -414,21 +324,17 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
     {
     }

-    __host__ __device__ constexpr index_t CalculateMLoops() const
-    {
-        return math::integer_divide_ceil(M_, MPerBlock_);
-    }
-
-    __host__ constexpr index_t CalculateGridSize(index_t /*M*/, index_t N) const
+    __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
     {
-        const auto M0 = 1; // math::integer_divide_ceil(M, MPerBlock);
+        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
         const auto N0 = math::integer_divide_ceil(N, NPerBlock);

         return M0 * N0 * KBatch_;
     }

     template <typename CGridDesc_M_N>
-    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
+    __host__ __device__ constexpr index_t
+    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
     {
         return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
     }
...
@@ -444,7 +350,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         {
             auto block_1d_id = idx_top[I0];

-            const auto M0 = 1; // math::integer_divide_ceil(M_, MPerBlock_);
+            const auto M0 = math::integer_divide_ceil(M_, MPerBlock_);
             const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);

             block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
...
@@ -495,24 +401,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         index_t StrideA_, StrideB_;
         std::array<index_t, NumDTensor> StrideDs_;
         index_t StrideE_;

-#if 0
-        // tensor descriptors for problem definiton
-        AGridDesc_M_K a_grid_desc_m_k_;
-        BGridDesc_N_K b_grid_desc_n_k_;
-        DsGridDesc_M_N ds_grid_desc_m_n_;
-        EGridDesc_M_N e_grid_desc_m_n_;
-
-        // tensor descriptors for block/thread-wise copy
-        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
-        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
-        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-            ds_grid_desc_mblock_mperblock_nblock_nperblock_;
-        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
-
-        // block-to-e-tile map
-        Block2ETileMap block_2_etile_map_;
-#endif
     };

     // Argument
...
@@ -561,13 +449,19 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         index_t group_id = 0;

+        sum_of_m = gemm_descs[0].M_;
+        const index_t AverM = sum_of_m / group_count_;
+
+        const index_t N = gemm_descs[0].N_;
+        const index_t K = gemm_descs[0].K_;
+
         for(std::size_t i = 0; i < gemm_descs.size(); i++)
         {
-            const index_t M = gemm_descs[i].M_;
-            const index_t N = gemm_descs[i].N_;
-            const index_t K = gemm_descs[i].K_;
+            if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_)
+            {
+                throw std::runtime_error("wrong! M/N/K is not identical");
+            }

-            a_mtx_mraw_kraw_.emplace_back(M, K);
+            a_mtx_mraw_kraw_.emplace_back(sum_of_m, K);
             b_mtx_nraw_kraw_.emplace_back(N, K);

             const index_t StrideA = gemm_descs[i].stride_A_;
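The argument-construction hunk above makes the fixed-NK contract explicit: every group must report the same M_/N_/K_, sum_of_m is taken from group 0, and AverM = sum_of_m / group_count_ is what descriptor construction and validity checks see. A host-side sketch of that validation, with GemmDesc standing in for the CK descriptor struct:

    #include <stdexcept>
    #include <vector>

    struct GemmDesc { int M_, N_, K_; }; // stand-in for the CK gemm descriptor

    int validated_aver_m(const std::vector<GemmDesc>& gemm_descs)
    {
        const int sum_of_m = gemm_descs[0].M_;
        const int N        = gemm_descs[0].N_;
        const int K        = gemm_descs[0].K_;

        for(const auto& d : gemm_descs)
        {
            if(sum_of_m != d.M_ || N != d.N_ || K != d.K_)
                throw std::runtime_error("wrong! M/N/K is not identical");
        }

        return sum_of_m / static_cast<int>(gemm_descs.size()); // AverM
    }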
...
@@ -584,12 +478,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
             static_cast<const DDataType*>(p_Ds.size() == 0 ? nullptr : p_Ds[i][j]);
         });

-        // tensor descriptors for problem definiton
-        // const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(M, K, StrideA);
-        // const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(K, N, StrideB);
-
-        // DsGridDesc_M_N ds_grid_desc_m_n;
-
         std::array<index_t, NumDTensor> StrideDs;

         static_for<0, NumDTensor, 1>{}([&](auto j) {
...
@@ -602,27 +490,20 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
             }

             StrideDs[j] = gemm_descs[i].stride_Ds_[j];
-
-            // ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
-            //    M, N, gemm_descs[i].stride_Ds_[j]);
         });

-#if 0
-        // tensor descriptors for block/thread-wise copy
-        const auto a_grid_desc_ak0_m_ak1 =
-            GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
-
-        const auto b_grid_desc_bk0_n_bk1 =
-            GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
-#endif
-
         const auto e_grid_desc_m_n =
-            DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
+            GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(AverM, N, StrideE);

         // block-to-e-tile map
         const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch};

         grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);

-        // std::cout << "group_id: " << group_id << " grid_size_grp: " << grid_size_grp
-        //<< std::endl;
-
         if(group_id * grid_size_grp != grid_size_)
         {
             throw std::runtime_error("wrong! grid_size_grp is not identical!");
...
@@ -638,7 +519,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
             if(!GridwiseGemm::
                    template CheckValidity<ALayout, BLayout, DsLayout, ELayout, GemmSpec>(
-                       M, N, K, StrideA, StrideB, StrideDs, StrideE, 1))
+                       AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1))
             {
                 throw std::runtime_error(
                     "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
...
@@ -649,7 +530,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                 p_Bs.size() == 0 ? nullptr : p_Bs[i],
                 p_ds_grid,
                 p_Es[i],
-                M,
+                AverM,
                 N,
                 K,
                 StrideA,
...
@@ -677,6 +558,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         index_t grid_size_;
         index_t grid_size_grp;
+        index_t sum_of_m;
     };

     // Invoker
...
@@ -735,38 +617,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         CDEElementwiseOperation,
         has_main_k_block_loop_>;

-    const void* kernel_args_dev = nullptr;
-
-    if(arg.grouped_gemm_kernel_args_dev != nullptr)
-    {
-        kernel_args_dev = arg.grouped_gemm_kernel_args_dev;
-    }
-    else
-    {
-        for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
-        {
-            if(arg.gemm_desc_kernel_arg_[i].a_ptr_ == nullptr ||
-               arg.gemm_desc_kernel_arg_[i].b_ptr_ == nullptr ||
-               arg.gemm_desc_kernel_arg_[i].e_ptr_ == nullptr)
-            {
-                throw std::runtime_error("wrong! p_a/b/c_grid is nullptr");
-            }
-        }
-
-        if(arg.p_workspace_ == nullptr)
-        {
-            throw std::runtime_error("wrong! arg.p_workspace_ == nullptr");
-        }
-
-        hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
-                                              grouped_gemm_kernel_args.data(),
-                                              grouped_gemm_kernel_args.size() *
-                                                  sizeof(GroupedGemmKernelArgument<NumDTensor>),
-                                              hipMemcpyHostToDevice,
-                                              stream_config.stream_id_));
-
-        kernel_args_dev = arg.p_workspace_;
-    }
+    if(arg.grouped_gemm_kernel_args_dev == nullptr)
+    {
+        throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
+    }

     return launch_and_time_kernel(
...
@@ -775,7 +628,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         dim3(arg.grid_size_),
         dim3(BlockSize),
         0,
-        cast_pointer_to_constant_address_space(kernel_args_dev),
+        cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
         arg.gemm_desc_kernel_arg_.size(),
         arg.grid_size_grp,
         k_batch,
...
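With this commit the invoker no longer falls back to copying kernel arguments into arg.p_workspace_; it requires arg.grouped_gemm_kernel_args_dev to be set and passes that pointer straight to the kernel launch. A hedged sketch of the staging a caller would now perform itself (stage_kernel_args is a hypothetical helper; hipMalloc and hipMemcpyWithStream are the same HIP calls the removed fallback used):

    #include <hip/hip_runtime.h>
    #include <stdexcept>
    #include <vector>

    // Copy host-side kernel arguments to a device buffer whose address can
    // then be handed to the device op as grouped_gemm_kernel_args_dev.
    template <typename KernelArg>
    void* stage_kernel_args(const std::vector<KernelArg>& host_args, hipStream_t stream)
    {
        void* dev_args = nullptr;
        if(hipMalloc(&dev_args, host_args.size() * sizeof(KernelArg)) != hipSuccess)
            throw std::runtime_error("hipMalloc failed");

        if(hipMemcpyWithStream(dev_args,
                               host_args.data(),
                               host_args.size() * sizeof(KernelArg),
                               hipMemcpyHostToDevice,
                               stream) != hipSuccess)
            throw std::runtime_error("hipMemcpyWithStream failed");

        return dev_args;
    }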