Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7a62d4a7
Commit
7a62d4a7
authored
May 04, 2023
by
Po-Yen, Chen
Browse files
Rename variables M/N/KRaw to M/N/K
parent
41449b67
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
75 additions
and
87 deletions
+75
-87
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
...or_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
+75
-87
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
View file @
7a62d4a7
...
@@ -82,19 +82,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -82,19 +82,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
static
auto
index_t
M
Raw
,
index_t
MPad
,
index_t
K
Raw
,
index_t
KPad
,
index_t
StrideA
)
MakeAGridDescriptor_AK0_M_AK1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
KPad
,
index_t
StrideA
)
{
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
make_tuple
(
StrideA
,
I1
));
}
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
make_tuple
(
I1
,
StrideA
));
}
}
}();
}();
...
@@ -108,8 +106,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -108,8 +106,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const
auto
a_grid_desc_m_k
=
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
M
Raw
,
MPad
-
M
Raw
),
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
make_right_pad_transform
(
K
Raw
,
KPad
-
K
Raw
)),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
@@ -126,14 +124,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -126,14 +124,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec
==
GemmSpecialization
::
MNPadding
)
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
{
// pad M, but not K
// pad M, but not K
assert
(
K
Raw
%
AK1
==
0
);
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
Raw
/
AK1
;
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_right_pad_transform
(
M
Raw
,
MPad
-
M
Raw
)),
make_right_pad_transform
(
M
,
MPad
-
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -147,17 +145,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -147,17 +145,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const
auto
AK0
=
KPad
/
AK1
;
const
auto
AK0
=
KPad
/
AK1
;
const
auto
a_grid_desc_m_k
=
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
a_grid_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_tuple
(
make_pass_through_transform
(
M
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_right_pad_transform
(
KRaw
,
KPad
-
KRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
Raw
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -166,14 +163,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -166,14 +163,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
else
else
{
{
// not pad M or K
// not pad M or K
assert
(
K
Raw
%
AK1
==
0
);
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
Raw
/
AK1
;
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
Raw
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -181,19 +178,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -181,19 +178,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
}
}
}
}
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
static
auto
index_t
K
Raw
,
index_t
KPad
,
index_t
N
Raw
,
index_t
NPad
,
index_t
StrideB
)
MakeBGridDescriptor_BK0_N_BK1
(
index_t
K
,
index_t
KPad
,
index_t
N
,
index_t
NPad
,
index_t
StrideB
)
{
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
return
make_naive_tensor_descriptor
(
make_tuple
(
N
,
K
),
make_tuple
(
I1
,
StrideB
));
make_tuple
(
I1
,
StrideB
));
}
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
return
make_naive_tensor_descriptor
(
make_tuple
(
N
,
K
),
make_tuple
(
StrideB
,
I1
));
make_tuple
(
StrideB
,
I1
));
}
}
}();
}();
...
@@ -207,8 +202,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -207,8 +202,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const
auto
b_grid_desc_n_k
=
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
N
Raw
,
NPad
-
N
Raw
),
make_tuple
(
make_right_pad_transform
(
N
,
NPad
-
N
),
make_right_pad_transform
(
K
Raw
,
KPad
-
K
Raw
)),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
@@ -225,14 +220,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -225,14 +220,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec
==
GemmSpecialization
::
MNPadding
)
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
{
// pad N, but not K
// pad N, but not K
assert
(
K
Raw
%
BK1
==
0
);
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
Raw
/
BK1
;
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_right_pad_transform
(
N
Raw
,
NPad
-
N
Raw
)),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -246,17 +241,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -246,17 +241,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const
auto
BK0
=
KPad
/
BK1
;
const
auto
BK0
=
KPad
/
BK1
;
const
auto
b_grid_desc_n_k
=
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
b_grid_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
NRaw
),
make_tuple
(
make_pass_through_transform
(
N
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_right_pad_transform
(
KRaw
,
KPad
-
KRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
Raw
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -265,14 +259,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -265,14 +259,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
else
else
{
{
// not pad N or K
// not pad N or K
assert
(
K
Raw
%
BK1
==
0
);
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
Raw
/
BK1
;
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
Raw
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -281,18 +275,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -281,18 +275,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
}
}
static
auto
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
Raw
,
index_t
MPad
,
index_t
N
Raw
,
index_t
NPad
,
index_t
StrideC
)
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
MPad
,
index_t
N
,
index_t
NPad
,
index_t
StrideC
)
{
{
const
auto
c_grid_desc_mraw_nraw
=
[
&
]()
{
const
auto
c_grid_desc_mraw_nraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I1
));
make_tuple
(
StrideC
,
I1
));
}
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
make_tuple
(
I1
,
StrideC
));
}
}
}();
}();
...
@@ -300,10 +292,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -300,10 +292,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
{
// pad M and N
// pad M and N
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
-
MRaw
),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_right_pad_transform
(
NRaw
,
NPad
-
NRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
...
@@ -313,8 +304,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -313,8 +304,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// pad M, but not N
// pad M, but not N
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
-
MRaw
),
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
make_pass_through_transform
(
N
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
...
@@ -324,8 +314,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -324,8 +314,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// pad N, but not M
// pad N, but not M
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
c_grid_desc_mraw_nraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_tuple
(
make_pass_through_transform
(
M
),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_right_pad_transform
(
NRaw
,
NPad
-
NRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
...
@@ -393,9 +382,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -393,9 +382,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
Argument
(
const
ADataType
*
p_a_grid
,
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
CDataType
*
p_c_grid
,
index_t
M
Raw
,
index_t
M
,
index_t
N
Raw
,
index_t
N
,
index_t
K
Raw
,
index_t
K
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
,
...
@@ -406,29 +395,28 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -406,29 +395,28 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
a_grid_desc_ak0_m_ak1_
{
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
M
Raw
,
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
M
,
GridwiseGemm
::
CalculateMPadded
(
M
Raw
),
GridwiseGemm
::
CalculateMPadded
(
M
),
K
Raw
,
K
,
GridwiseGemm
::
CalculateKPadded
(
K
Raw
),
GridwiseGemm
::
CalculateKPadded
(
K
),
StrideA
)},
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
K
Raw
,
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
K
,
GridwiseGemm
::
CalculateKPadded
(
K
Raw
),
GridwiseGemm
::
CalculateKPadded
(
K
),
N
Raw
,
N
,
GridwiseGemm
::
CalculateNPadded
(
N
Raw
),
GridwiseGemm
::
CalculateNPadded
(
N
),
StrideB
)},
StrideB
)},
c_grid_desc_m_n_
{
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
M
,
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateMPadded
(
MRaw
),
N
,
NRaw
,
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateNPadded
(
NRaw
),
StrideC
)},
StrideC
)},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
)},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
},
c_element_op_
{
c_element_op
},
kraw_
{
K
Raw
}
kraw_
{
K
}
{
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
b_grid_desc_bk0_n_bk1_
,
...
@@ -608,9 +596,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -608,9 +596,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
CDataType
*
p_c
,
index_t
M
Raw
,
index_t
M
,
index_t
N
Raw
,
index_t
N
,
index_t
K
Raw
,
index_t
K
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
,
...
@@ -621,9 +609,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -621,9 +609,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
return
Argument
{
p_a
,
return
Argument
{
p_a
,
p_b
,
p_b
,
p_c
,
p_c
,
M
Raw
,
M
,
N
Raw
,
N
,
K
Raw
,
K
,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideC
,
StrideC
,
...
@@ -638,9 +626,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -638,9 +626,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_b
,
void
*
p_c
,
void
*
p_c
,
index_t
M
Raw
,
index_t
M
,
index_t
N
Raw
,
index_t
N
,
index_t
K
Raw
,
index_t
K
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
,
...
@@ -651,9 +639,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -651,9 +639,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
CDataType
*>
(
p_c
),
M
Raw
,
M
,
N
Raw
,
N
,
K
Raw
,
K
,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideC
,
StrideC
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment