Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
dfbb659a
Commit
dfbb659a
authored
Jul 26, 2022
by
Chao Liu
Browse files
upate contraction example
parent
809799bf
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
128 additions
and
334 deletions
+128
-334
include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp
...or_operation/gpu/device/device_contraction_multiple_d.hpp
+8
-8
include/ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp
...gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp
+119
-323
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
+1
-3
No files found.
include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp
View file @
dfbb659a
...
@@ -43,14 +43,14 @@ struct DeviceContractionMultipleD : public BaseOperator
...
@@ -43,14 +43,14 @@ struct DeviceContractionMultipleD : public BaseOperator
const
void
*
p_b
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
void
*
p_e
,
std
::
vector
<
index_t
>
a_ms_
k
s_lengths
,
const
std
::
vector
<
index_t
>
&
a_ms_
n
s_lengths
,
std
::
vector
<
index_t
>
a_ms_ks_strides
,
const
std
::
vector
<
index_t
>
&
a_ms_ks_strides
,
std
::
vector
<
index_t
>
b_ns_ks_lengths
,
const
std
::
vector
<
index_t
>
&
b_ns_ks_lengths
,
std
::
vector
<
index_t
>
b_ns_ks_strides
,
const
std
::
vector
<
index_t
>
&
b_ns_ks_strides
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
&
ds_ms_ns_lengths
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
&
ds_ms_ns_strides
,
std
::
vector
<
index_t
>
e_ms_ns_lengths
,
const
std
::
vector
<
index_t
>
&
e_ms_ns_lengths
,
std
::
vector
<
index_t
>
e_ms_ns_strides
,
const
std
::
vector
<
index_t
>
&
e_ms_ns_strides
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
=
0
;
CDEElementwiseOperation
cde_element_op
)
=
0
;
...
...
include/ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp
View file @
dfbb659a
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
@@ -106,7 +107,7 @@ template <index_t NumDimM,
...
@@ -106,7 +107,7 @@ template <index_t NumDimM,
index_t
NumDimK
,
index_t
NumDimK
,
typename
ADataType
,
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
Gemm
AccDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
EDataType
,
...
@@ -165,9 +166,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -165,9 +166,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
// Assume: A[M0, M1, M2, ..., K0, K1, K2, ...]
// Assume: A[M0, M1, M2, ..., K0, K1, K2, ...]
static
auto
MakeAGridDescriptor_
AK0_M_AK1
(
const
std
::
vector
<
index_t
>&
a_ms_ks_lengths_vec
,
static
auto
MakeAGridDescriptor_
M_K
(
const
std
::
vector
<
index_t
>&
a_ms_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
a_ms_ks_strides_vec
)
const
std
::
vector
<
index_t
>&
a_ms_ks_strides_vec
)
{
{
assert
(
a_ms_ks_lengths_vec
.
size
()
==
NumDimM
+
NumDimK
&&
assert
(
a_ms_ks_lengths_vec
.
size
()
==
NumDimM
+
NumDimK
&&
a_ms_ks_strides_vec
.
size
()
==
NumDimM
+
NumDimK
);
a_ms_ks_strides_vec
.
size
()
==
NumDimM
+
NumDimK
);
...
@@ -203,100 +207,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -203,100 +207,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
make_tuple
(
mDimIds
,
kDimIds
),
make_tuple
(
mDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
MRaw
=
a_grid_desc_mraw_kraw
.
GetLength
(
I0
);
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
const
auto
KRaw
=
a_grid_desc_mraw_kraw
.
GetLength
(
I1
);
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both M and K
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad M, but not K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
{
// not pad M or K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
}
}
// Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
// Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
static
auto
MakeBGridDescriptor_
BK0_N_BK1
(
const
std
::
vector
<
index_t
>&
b_ns_ks_lengths_vec
,
static
auto
MakeBGridDescriptor_
N_K
(
const
std
::
vector
<
index_t
>&
b_ns_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
b_ns_ks_strides_vec
)
const
std
::
vector
<
index_t
>&
b_ns_ks_strides_vec
)
{
{
assert
(
b_ns_ks_lengths_vec
.
size
()
==
NumDimN
+
NumDimK
&&
assert
(
b_ns_ks_lengths_vec
.
size
()
==
NumDimN
+
NumDimK
&&
b_ns_ks_strides_vec
.
size
()
==
NumDimN
+
NumDimK
);
b_ns_ks_strides_vec
.
size
()
==
NumDimN
+
NumDimK
);
...
@@ -332,95 +248,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -332,95 +248,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
make_tuple
(
nDimIds
,
kDimIds
),
make_tuple
(
nDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
NRaw
=
b_grid_desc_nraw_kraw
.
GetLength
(
I0
);
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
const
auto
KRaw
=
b_grid_desc_nraw_kraw
.
GetLength
(
I1
);
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both N and K
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad N, but not K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad K, but not N
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
NRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
{
// not pad N or K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
}
}
// assume E[M0, M1, M2, ..., N0, N1, N2...]
// assume E[M0, M1, M2, ..., N0, N1, N2...]
...
@@ -461,63 +289,30 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -461,63 +289,30 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
make_tuple
(
mDimIds
,
nDimIds
),
make_tuple
(
mDimIds
,
nDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
MRaw
=
e_grid_desc_mraw_nraw
.
GetLength
(
I0
);
return
matrix_padder
.
PadCDescriptor_M_N
(
e_grid_desc_mraw_nraw
);
const
auto
NRaw
=
e_grid_desc_mraw_nraw
.
GetLength
(
I1
);
}
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
NPad
=
N
-
NRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
||
static
auto
MakeDsGridDescriptor_M_N
(
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_ms_ns_lengths_vec
,
{
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>&
ds_ms_ns_strides_vec
)
// pad M and N
{
return
transform_tensor_descriptor
(
e_grid_desc_mraw_nraw
,
return
generate_tuple
(
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
[
&
](
auto
i
)
{
make_right_pad_transform
(
NRaw
,
NPad
)),
return
DeviceOp
::
MakeEGridDescriptor_M_N
(
ds_ms_ns_lengths_vec
[
i
],
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
ds_ms_ns_strides_vec
[
i
]);
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
},
}
Number
<
NumDTensor
>
{});
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad M, but not N
return
transform_tensor_descriptor
(
e_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad N, but not M
return
transform_tensor_descriptor
(
e_grid_desc_mraw_nraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad M or N
return
e_grid_desc_mraw_nraw
;
}
}
}
using
AGridDesc_AK0_M_AK1
=
using
AGridDesc_M_K
=
decltype
(
MakeAGridDescriptor_M_K
({},
{}));
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
std
::
vector
<
index_t
>
{},
std
::
vector
<
index_t
>
{}));
using
BGridDesc_N_K
=
decltype
(
MakeBGridDescriptor_N_K
({},
{}));
using
BGridDesc_BK0_N_BK1
=
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({{}},
{{}}))
>
;
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
std
::
vector
<
index_t
>
{},
std
::
vector
<
index_t
>
{}));
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
({},
{}));
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
(
std
::
vector
<
index_t
>
{},
std
::
vector
<
index_t
>
{}));
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_
k0mk1_k0nk1_mn_
xdl_cshuffle
<
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
Gemm
AccDataType
,
AccDataType
,
CShuffleDataType
,
CShuffleDataType
,
DsDataType
,
DsDataType
,
EDataType
,
EDataType
,
...
@@ -525,8 +320,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -525,8 +320,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
AGridDesc_M_K
,
BGridDesc_BK0_N_BK1
,
BGridDesc_N_K
,
DsGridDesc_M_N
,
EGridDesc_M_N
,
EGridDesc_M_N
,
NumGemmKPrefetchStage
,
NumGemmKPrefetchStage
,
BlockSize
,
BlockSize
,
...
@@ -561,6 +357,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -561,6 +357,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
CDEBlockTransferScalarPerVector_NPerBlock
,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
LoopSched
>
;
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
Block2ETileMap
=
typename
GridwiseGemm
::
DefaultBlock2ETileMap
;
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -568,27 +371,30 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -568,27 +371,30 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const
void
*
p_b_grid
,
const
void
*
p_b_grid
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
void
*
p_e_grid
,
void
*
p_e_grid
,
std
::
vector
<
index_t
>
a_ms_ns_lengths
,
const
std
::
vector
<
index_t
>
&
a_ms_ns_lengths
,
std
::
vector
<
index_t
>
a_ms_ks_strides
,
const
std
::
vector
<
index_t
>
&
a_ms_ks_strides
,
std
::
vector
<
index_t
>
b_ns_ks_lengths
,
const
std
::
vector
<
index_t
>
&
b_ns_ks_lengths
,
std
::
vector
<
index_t
>
b_ns_ks_strides
,
const
std
::
vector
<
index_t
>
&
b_ns_ks_strides
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
&
ds_ms_ns_lengths
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
&
ds_ms_ns_strides
,
std
::
vector
<
index_t
>
e_ms_ns_lengths
,
const
std
::
vector
<
index_t
>
&
e_ms_ns_lengths
,
std
::
vector
<
index_t
>
e_ms_ns_strides
,
const
std
::
vector
<
index_t
>
&
e_ms_ns_strides
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
CDEElementwiseOperation
cde_element_op
)
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a_grid
)},
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a_grid
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
// FIXME
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
(
a_ms_ns_lengths
,
a_ms_ks_strides
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
b_ns_ks_lengths
,
b_ns_ks_strides
)},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
(
e_ms_ns_lengths
,
e_ms_ns_strides
)},
a_grid_desc_ak0_m_ak1_
{
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
Make
AGridDescriptor_AK0_M_AK1
(
a_
ms_ns_lengths
,
a
_m
s
_k
s_strides
)},
GridwiseGemm
::
MakeDefault
AGridDescriptor_AK0_M_AK1
(
a_
grid_desc
_m_k
_
)},
b_grid_desc_bk0_n_bk1_
{
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
Make
BGridDescriptor_BK0_N_BK1
(
b_
ns_ks_lengths
,
b
_n
s
_k
s_strides
)},
GridwiseGemm
::
MakeDefault
BGridDescriptor_BK0_N_BK1
(
b_
grid_desc
_n_k
_
)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
(
e_ms_ns_lengths
,
e_ms_ns_strides
)},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
...
@@ -601,8 +407,22 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -601,8 +407,22 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
ds_nz_stride_
{},
ds_nz_stride_
{},
e_nz_stride_
{}
e_nz_stride_
{}
{
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
// populate pointer, batch stride, desc for Ds
b_grid_desc_bk0_n_bk1_
,
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
// D pointer
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds_grid
[
i
]);
// D desc
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
(
ds_ms_ns_lengths
[
i
],
ds_ms_ns_strides
[
i
]);
});
// populate desc for Ds/E
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k_
,
b_grid_desc_n_k_
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
block_2_etile_map_
))
{
{
...
@@ -610,18 +430,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -610,18 +430,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
e_grid_desc_m_n_
);
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
);
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds_grid
[
i
]);
const
auto
d_grid_desc_m_n
=
DeviceOp
::
MakeEGridDescriptor_M_N
(
ds_ms_ns_lengths
[
i
],
ds_ms_ns_strides
[
i
]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
(
i
)
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
d_grid_desc_m_n
);
});
}
}
// for sanity check of vector memory access
// for sanity check of vector memory access
...
@@ -639,6 +450,15 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -639,6 +450,15 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
e_nz_stride_
=
e_ms_ns_strides
[
NumDimM
+
NumDimN
-
1
];
e_nz_stride_
=
e_ms_ns_strides
[
NumDimM
+
NumDimN
-
1
];
}
}
void
Print
()
const
{
std
::
cout
<<
"A[M, K]: "
<<
a_grid_desc_m_k_
<<
std
::
endl
;
std
::
cout
<<
"B[N, K]: "
<<
b_grid_desc_n_k_
<<
std
::
endl
;
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
std
::
cout
<<
"Ds[M, N]: "
<<
ds_grid_desc_m_n_
[
i
]
<<
std
::
endl
;
});
std
::
cout
<<
"E[M, N]: "
<<
e_grid_desc_m_n_
<<
std
::
endl
;
}
// private:
// private:
// pointers
// pointers
const
ADataType
*
p_a_grid_
;
const
ADataType
*
p_a_grid_
;
...
@@ -646,20 +466,22 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -646,20 +466,22 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
EDataType
*
p_e_grid_
;
// tensor descriptors
// tensor descriptors for problem definiton
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
StaticallyIndexedArray
<
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
NumDTensor
>
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
// FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N
e_grid_desc_m_n_
;
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
// block-to-e-tile map
// block-to-e-tile map
typename
GridwiseGemm
::
Default
Block2ETileMap
block_2_etile_map_
;
Block2ETileMap
block_2_etile_map_
;
// element-wise op
// element-wise op
AElementwiseOperation
a_element_op_
;
AElementwiseOperation
a_element_op_
;
...
@@ -684,29 +506,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -684,29 +506,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
#if 0
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
{
arg
.
b_grid_desc_n_k_
,
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
arg
.
ds_grid_desc_m_n_
,
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
e_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
arg
.
block_2_etile_map_
))
{
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"
);
}
}
const
index_t
grid_size
=
const
index_t
grid_size
=
...
@@ -728,9 +535,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -728,9 +535,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
CDEElementwiseOperation
,
CDEElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
ck
::
StaticallyIndexedArray
<
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2ETileMap
,
typename
GridwiseGemm
::
DefaultBlock2ETileMap
,
has_main_loop
>
;
has_main_loop
>
;
...
@@ -754,18 +559,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -754,18 +559,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg
.
block_2_etile_map_
);
arg
.
block_2_etile_map_
);
};
};
float
ave_time
=
0
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
}
else
else
{
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
return
ave_time
;
}
}
// polymorphic
// polymorphic
...
@@ -776,12 +577,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -776,12 +577,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
}
}
};
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
...
@@ -789,8 +584,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -789,8 +584,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
return
false
;
return
false
;
}
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
arg
.
block_2_etile_map_
))
{
{
...
@@ -878,14 +674,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -878,14 +674,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const
void
*
p_b
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
void
*
p_e
,
std
::
vector
<
index_t
>
a_ms_ns_lengths
,
const
std
::
vector
<
index_t
>
&
a_ms_ns_lengths
,
std
::
vector
<
index_t
>
a_ms_ks_strides
,
const
std
::
vector
<
index_t
>
&
a_ms_ks_strides
,
std
::
vector
<
index_t
>
b_ns_ks_lengths
,
const
std
::
vector
<
index_t
>
&
b_ns_ks_lengths
,
std
::
vector
<
index_t
>
b_ns_ks_strides
,
const
std
::
vector
<
index_t
>
&
b_ns_ks_strides
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
&
ds_ms_ns_lengths
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
&
ds_ms_ns_strides
,
std
::
vector
<
index_t
>
e_ms_ns_lengths
,
const
std
::
vector
<
index_t
>
&
e_ms_ns_lengths
,
std
::
vector
<
index_t
>
e_ms_ns_strides
,
const
std
::
vector
<
index_t
>
&
e_ms_ns_strides
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
CDEElementwiseOperation
cde_element_op
)
...
@@ -915,14 +711,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -915,14 +711,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const
void
*
p_b
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
void
*
p_e
,
std
::
vector
<
index_t
>
a_ms_ns_lengths
,
const
std
::
vector
<
index_t
>
&
a_ms_ns_lengths
,
std
::
vector
<
index_t
>
a_ms_ks_strides
,
const
std
::
vector
<
index_t
>
&
a_ms_ks_strides
,
std
::
vector
<
index_t
>
b_ns_ks_lengths
,
const
std
::
vector
<
index_t
>
&
b_ns_ks_lengths
,
std
::
vector
<
index_t
>
b_ns_ks_strides
,
const
std
::
vector
<
index_t
>
&
b_ns_ks_strides
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
&
ds_ms_ns_lengths
,
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
ds_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumDTensor
>
&
ds_ms_ns_strides
,
std
::
vector
<
index_t
>
e_ms_ns_lengths
,
const
std
::
vector
<
index_t
>
&
e_ms_ns_lengths
,
std
::
vector
<
index_t
>
e_ms_ns_strides
,
const
std
::
vector
<
index_t
>
&
e_ms_ns_strides
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
override
CDEElementwiseOperation
cde_element_op
)
override
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
dfbb659a
...
@@ -434,9 +434,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -434,9 +434,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
CDEElementwiseOperation
,
CDEElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
ck
::
StaticallyIndexedArray
<
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2ETileMap
,
typename
GridwiseGemm
::
DefaultBlock2ETileMap
,
has_main_loop
>
;
has_main_loop
>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment