gaoqiong / composable_kernel

Commit edbed47e, authored Aug 22, 2022 by Anthony Chang

refactor MatrixPadder; rename to GemmPadder

parent f91dad8e
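This commit renames MatrixPadder to GemmPadder and replaces its per-specialization `if constexpr` chains with three compile-time flags (PadM, PadN, PadK) that are forwarded to PadTensorDescriptor. For context, the removed branches rounded each raw GEMM dimension up to the next multiple of its tile size and right-padded by the difference. A minimal standalone sketch of that arithmetic is shown below; it uses a local integer_divide_ceil helper and plain integers rather than CK's descriptor machinery, so it is an illustration of the math only, not CK code:

#include <cstdint>
#include <iostream>

using index_t = int32_t;

// Local stand-in for ck::math::integer_divide_ceil.
constexpr index_t integer_divide_ceil(index_t x, index_t y) { return (x + y - 1) / y; }

// Round a raw GEMM dimension up to the next multiple of the tile size,
// as the removed MatrixPadder branches did before calling make_right_pad_transform.
constexpr index_t padded_length(index_t raw, index_t per_tile)
{
    return integer_divide_ceil(raw, per_tile) * per_tile;
}

int main()
{
    const index_t MRaw     = 1000; // raw (unpadded) M
    const index_t MPerTile = 256;  // M tile size of the GEMM kernel
    const index_t M        = padded_length(MRaw, MPerTile); // 1024
    const index_t MPad     = M - MRaw;                       // 24 = size of the right pad
    std::cout << "M = " << M << ", MPad = " << MPad << "\n";
    return 0;
}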
Showing 1 changed file with 22 additions and 138 deletions.

include/ck/tensor_operation/gpu/device/matrix_padder.hpp (+22, -138)
@@ -78,7 +78,6 @@ PadTensorDescriptor(const TensorDesc_GRaw_MRaw_NRaw& tensor_desc_graw_mraw_nraw,
         make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
 }
-
 
 // M/N/K/OPerTileType could be index_t or Number<>
 template <GemmSpecialization GemmSpec,
           typename MPerTileType,
@@ -152,161 +151,37 @@ template <GemmSpecialization GemmSpec,
           typename MPerTileType,
           typename NPerTileType,
           typename KPerTileType>
-struct MatrixPadder
+struct GemmPadder
 {
-    static constexpr auto I0 = Number<0>{};
-    static constexpr auto I1 = Number<1>{};
-    static constexpr auto I2 = Number<2>{};
-    static constexpr auto I3 = Number<3>{};
+    static constexpr bool PadM =
+        (GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding ||
+         GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding);
+    static constexpr bool PadN =
+        (GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding ||
+         GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding);
+    static constexpr bool PadK =
+        (GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding ||
+         GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding);
 
     template <typename ADesc_MRaw_KRaw>
     __host__ __device__ constexpr auto
     PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
     {
-        const auto MRaw = a_desc_mraw_kraw.GetLength(I0);
-        const auto KRaw = a_desc_mraw_kraw.GetLength(I1);
-
-        const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
-        const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
-
-        const auto MPad = M - MRaw;
-        const auto KPad = K - KRaw;
-
-        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
-                     GemmSpec == GemmSpecialization::MNKPadding)
-        {
-            // pad both M and K
-            return transform_tensor_descriptor(
-                a_desc_mraw_kraw,
-                make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
-                          GemmSpec == GemmSpecialization::MNPadding)
-        {
-            // pad M, but not K
-            return transform_tensor_descriptor(
-                a_desc_mraw_kraw,
-                make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(KRaw)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
-                          GemmSpec == GemmSpecialization::NKPadding)
-        {
-            // pad K, but not M
-            return transform_tensor_descriptor(
-                a_desc_mraw_kraw,
-                make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else
-        {
-            // not pad M or K
-            return a_desc_mraw_kraw;
-        }
+        return PadTensorDescriptor<PadM, PadK>(a_desc_mraw_kraw, MPerTile_, KPerTile_);
     }
 
     template <typename BDesc_NRaw_KRaw>
     __host__ __device__ constexpr auto
     PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
     {
-        const auto NRaw = b_desc_nraw_kraw.GetLength(I0);
-        const auto KRaw = b_desc_nraw_kraw.GetLength(I1);
-
-        const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
-        const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
-
-        const auto NPad = N - NRaw;
-        const auto KPad = K - KRaw;
-
-        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
-                     GemmSpec == GemmSpecialization::MNKPadding)
-        {
-            // pad both N and K
-            return transform_tensor_descriptor(
-                b_desc_nraw_kraw,
-                make_tuple(make_right_pad_transform(NRaw, NPad), make_right_pad_transform(KRaw, KPad)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
-                          GemmSpec == GemmSpecialization::MNPadding)
-        {
-            // pad N, but not K
-            return transform_tensor_descriptor(
-                b_desc_nraw_kraw,
-                make_tuple(make_right_pad_transform(NRaw, NPad), make_pass_through_transform(KRaw)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
-                          GemmSpec == GemmSpecialization::MKPadding)
-        {
-            // pad K, but not N
-            return transform_tensor_descriptor(
-                b_desc_nraw_kraw,
-                make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else
-        {
-            // not pad N or K
-            return b_desc_nraw_kraw;
-        }
+        return PadTensorDescriptor<PadN, PadK>(b_desc_nraw_kraw, NPerTile_, KPerTile_);
     }
 
     template <typename CDesc_MRaw_NRaw>
     __host__ __device__ constexpr auto
     PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
     {
-        const auto MRaw = c_desc_mraw_nraw.GetLength(I0);
-        const auto NRaw = c_desc_mraw_nraw.GetLength(I1);
-
-        const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
-        const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
-
-        const auto MPad = M - MRaw;
-        const auto NPad = N - NRaw;
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
-                     GemmSpec == GemmSpecialization::MNKPadding)
-        {
-            // pad M and N
-            return transform_tensor_descriptor(
-                c_desc_mraw_nraw,
-                make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(NRaw, NPad)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
-                          GemmSpec == GemmSpecialization::MKPadding)
-        {
-            // pad M, but not N
-            return transform_tensor_descriptor(
-                c_desc_mraw_nraw,
-                make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
-                          GemmSpec == GemmSpecialization::NKPadding)
-        {
-            // pad N, but not M
-            return transform_tensor_descriptor(
-                c_desc_mraw_nraw,
-                make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else
-        {
-            // not pad M or N
-            return c_desc_mraw_nraw;
-        }
+        return PadTensorDescriptor<PadM, PadN>(c_desc_mraw_nraw, MPerTile_, NPerTile_);
     }
 
     MPerTileType MPerTile_;
@@ -314,6 +189,15 @@ struct MatrixPadder
     KPerTileType KPerTile_;
 };
 
+// Alias of GemmPadder; to deprecate
+template <GemmSpecialization GemmSpec,
+          typename MPerTileType,
+          typename NPerTileType,
+          typename KPerTileType>
+struct MatrixPadder : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KPerTileType>
+{
+};
+
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
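After the change, GemmPadder carries only the three tile sizes and the compile-time pad flags, and MatrixPadder survives as a thin alias marked for deprecation. A minimal sketch of the flag derivation in isolation is shown below; the enum is a local stand-in listing only the enumerators referenced in this diff, not CK's full GemmSpecialization definition:

#include <iostream>

// Local stand-in for ck::tensor_operation::device::GemmSpecialization;
// only the values referenced in this diff are listed.
enum class GemmSpecialization
{
    Default, MPadding, NPadding, KPadding,
    MNPadding, MKPadding, NKPadding, MNKPadding
};

// Mirrors the PadM/PadN/PadK flags added to GemmPadder in this commit.
template <GemmSpecialization GemmSpec>
struct PadFlags
{
    static constexpr bool PadM =
        (GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding ||
         GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding);
    static constexpr bool PadN =
        (GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding ||
         GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding);
    static constexpr bool PadK =
        (GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding ||
         GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding);
};

int main()
{
    using Flags = PadFlags<GemmSpecialization::MKPadding>;
    std::cout << "PadM=" << Flags::PadM << " PadN=" << Flags::PadN
              << " PadK=" << Flags::PadK << "\n"; // PadM=1 PadN=0 PadK=1
    return 0;
}

With the flags folded into the type, each Pad*Descriptor_* member reduces to a single call such as PadTensorDescriptor<PadM, PadK>(a_desc_mraw_kraw, MPerTile_, KPerTile_), which is what shrinks the file by 138 lines.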