gaoqiong / composable_kernel_ROCM · Commits

Commit b3054fea, authored Jan 22, 2025 by Adam Osewski

    Merge branch 'develop' into aosewski/ck_tile_gemm_policy

Parents: 7cbc1492, 052a7265

Changes: 54. Showing 14 changed files with 339 additions and 376 deletions (+339 −376).
Files changed on this page:

    include/ck_tile/core/utility/unary_element_function.hpp                      +9    −7
    include/ck_tile/host/check_err.hpp                                           +18   −36
    include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp           +27   −57
    include/ck_tile/host/host_tensor.hpp                                         +34   −1
    include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp          +1    −1
    include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp                      +6    −3
    include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp                              +52   −48
    include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp                    +115  −28
    include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp                      +68   −187
    include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp            +2    −1
    include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp   +1    −1
    include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp   +1    −1
    test/ck_tile/batched_gemm/test_batched_gemm_util.hpp                         +4    −4
    test/ck_tile/gemm/test_gemm_pipeline_util.hpp                                +1    −1
include/ck_tile/core/utility/unary_element_function.hpp  (+9 −7)

@@ -51,16 +51,18 @@ struct composes<F>
 template <typename... Ts>
 __host__ __device__ composes(Ts&&...) -> composes<remove_cvref_t<Ts>...>;

-template <typename To>
+template <typename SaturateType>
 struct saturates
 {
-    template <typename From>
-    CK_TILE_HOST_DEVICE constexpr auto operator()(const From& from) const
-        -> std::enable_if_t<std::is_arithmetic_v<From>, From>
+    // NOTE: this function does not return SaturateType value
+    // it is user's responsiblity to do further cast or not
+    template <typename AccType>
+    CK_TILE_HOST_DEVICE constexpr auto operator()(const AccType& a_) const
+        -> std::enable_if_t<std::is_arithmetic_v<AccType>, AccType>
     {
-        return clamp(from,
-                     type_convert<From>(numeric<To>::lowest()),
-                     type_convert<From>(numeric<To>::max()));
+        return clamp(a_,
+                     type_convert<AccType>(numeric<SaturateType>::lowest()),
+                     type_convert<AccType>(numeric<SaturateType>::max()));
     }
 };
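Note: with the new semantics, `saturates<SaturateType>` clamps its argument into the range representable by `SaturateType` but leaves the result in the accumulator type; any narrowing conversion is an explicit, separate step for the caller. A brief usage sketch (values and the enclosing `ck_tile::` namespace are illustrative, not part of this commit):

    // Illustrative sketch: clamp a float accumulator into int8_t range, then convert explicitly.
    float  acc = 300.0f;                               // AccType deduced as float
    float  sat = ck_tile::saturates<int8_t>{}(acc);    // clamped to 127.0f, still a float
    int8_t q   = ck_tile::type_convert<int8_t>(sat);   // the cast is the caller's responsibility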
include/ck_tile/host/check_err.hpp  (+18 −36)

@@ -28,14 +28,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
     using I8  = int8_t;
     using I32 = int32_t;

-    static_assert(std::is_same_v<ComputeDataType, F8> || std::is_same_v<ComputeDataType, F16> || std::is_same_v<ComputeDataType, BF16> || std::is_same_v<ComputeDataType, F32> ||
-                      std::is_same_v<ComputeDataType, I8> || std::is_same_v<ComputeDataType, I32> || std::is_same_v<ComputeDataType, int>,
+    static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled ComputeDataType for setting up the relative threshold!");

     double compute_error = 0;
-    if constexpr(std::is_same_v<ComputeDataType, I8> || std::is_same_v<ComputeDataType, I32> || std::is_same_v<ComputeDataType, int>)
+    if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
     {
         return 0;
     }

@@ -44,14 +41,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
         compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
     }

-    static_assert(std::is_same_v<OutDataType, F8> || std::is_same_v<OutDataType, F16> || std::is_same_v<OutDataType, BF16> || std::is_same_v<OutDataType, F32> ||
-                      std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> || std::is_same_v<OutDataType, int>,
+    static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled OutDataType for setting up the relative threshold!");

     double output_error = 0;
-    if constexpr(std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> || std::is_same_v<OutDataType, int>)
+    if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
     {
         return 0;
     }

@@ -61,14 +55,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
     }
     double midway_error = std::max(compute_error, output_error);

-    static_assert(std::is_same_v<AccDataType, F8> || std::is_same_v<AccDataType, F16> || std::is_same_v<AccDataType, BF16> || std::is_same_v<AccDataType, F32> ||
-                      std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> || std::is_same_v<AccDataType, int>,
+    static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled AccDataType for setting up the relative threshold!");

     double acc_error = 0;
-    if constexpr(std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> || std::is_same_v<AccDataType, int>)
+    if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
     {
         return 0;
     }

@@ -89,15 +80,12 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
     using I8  = int8_t;
     using I32 = int32_t;

-    static_assert(std::is_same_v<ComputeDataType, F8> || std::is_same_v<ComputeDataType, F16> || std::is_same_v<ComputeDataType, BF16> || std::is_same_v<ComputeDataType, F32> ||
-                      std::is_same_v<ComputeDataType, I8> || std::is_same_v<ComputeDataType, I32> || std::is_same_v<ComputeDataType, int>,
+    static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");

     auto expo            = std::log2(std::abs(max_possible_num));
     double compute_error = 0;
-    if constexpr(std::is_same_v<ComputeDataType, I8> || std::is_same_v<ComputeDataType, I32> || std::is_same_v<ComputeDataType, int>)
+    if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
     {
         return 0;
     }

@@ -106,14 +94,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
         compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
     }

-    static_assert(std::is_same_v<OutDataType, F8> || std::is_same_v<OutDataType, F16> || std::is_same_v<OutDataType, BF16> || std::is_same_v<OutDataType, F32> ||
-                      std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> || std::is_same_v<OutDataType, int>,
+    static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled OutDataType for setting up the absolute threshold!");

     double output_error = 0;
-    if constexpr(std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> || std::is_same_v<OutDataType, int>)
+    if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
     {
         return 0;
     }

@@ -123,14 +108,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
     }
     double midway_error = std::max(compute_error, output_error);

-    static_assert(std::is_same_v<AccDataType, F8> || std::is_same_v<AccDataType, F16> || std::is_same_v<AccDataType, BF16> || std::is_same_v<AccDataType, F32> ||
-                      std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> || std::is_same_v<AccDataType, int>,
+    static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
                   "Warning: Unhandled AccDataType for setting up the absolute threshold!");

     double acc_error = 0;
-    if constexpr(std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> || std::is_same_v<AccDataType, int>)
+    if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
     {
         return 0;
     }
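Note: the `is_any_of` trait used above to collapse the repeated `std::is_same_v` chains is not defined in this diff. A minimal sketch of such a trait, assuming the usual disjunction-of-`is_same` idiom (not necessarily the ck_tile definition):

    #include <type_traits>

    // Minimal sketch: true when T matches any type in the Ts... pack.
    template <typename T, typename... Ts>
    struct is_any_of : std::disjunction<std::is_same<T, Ts>...>
    {
    };

    // Usage mirrors the new asserts above:
    // static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value, "...");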
include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp  (+27 −57)

@@ -14,57 +14,41 @@ namespace detail {
 template <typename OldLayout>
 CK_TILE_HOST std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
 {
-    if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCW> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCX> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKW>)
+    using namespace ck_tile::tensor_layout::convolution;
+
+    if constexpr(is_any_of<OldLayout, GNCW, GKCX, GNKW>::value)
     {
         return {0, 1, 2, 3};
     }
-    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCHW> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCYX> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKHW>)
+    else if constexpr(is_any_of<OldLayout, GNCHW, GKCYX, GNKHW>::value)
     {
         return {0, 1, 2, 3, 4};
     }
-    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCDHW> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCZYX> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
+    else if constexpr(is_any_of<OldLayout, GNCDHW, GKCZYX, GNKDHW>::value)
     {
         return {0, 1, 2, 3, 4, 5};
     }
-    if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWC> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKXC> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWK>)
+    if constexpr(is_any_of<OldLayout, GNWC, GKXC, GNWK>::value)
     {
         return {0, 1, 3, 2};
     }
-    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWC> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKYXC> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWK>)
+    else if constexpr(is_any_of<OldLayout, GNHWC, GKYXC, GNHWK>::value)
     {
         return {0, 1, 4, 2, 3};
     }
-    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWC> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKZYXC> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
+    else if constexpr(is_any_of<OldLayout, GNDHWC, GKZYXC, GNDHWK>::value)
     {
         return {0, 1, 5, 2, 3, 4};
     }
-    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGC> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KXGC> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGK>)
+    else if constexpr(is_any_of<OldLayout, NWGC, KXGC, NWGK>::value)
     {
         return {2, 0, 3, 1};
     }
-    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGC> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KYXGC> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGK>)
+    else if constexpr(is_any_of<OldLayout, NHWGC, KYXGC, NHWGK>::value)
    {
         return {3, 0, 4, 1, 2};
     }
-    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGC> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KZYXGC> || std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
+    else if constexpr(is_any_of<OldLayout, NDHWGC, KZYXGC, NDHWGK>::value)
     {
         return {4, 0, 5, 1, 2, 3};
     }

@@ -83,11 +67,11 @@ template <typename InLayout>
 CK_TILE_HOST HostTensorDescriptor
 make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvParam& param)
 {
+    using namespace ck_tile::tensor_layout::convolution;
+
     std::vector<std::size_t> physical_lengths;

-    if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCW> || std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCHW> || std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCDHW>)
+    if constexpr(is_any_of<InLayout, GNCW, GNCHW, GNCDHW>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                     static_cast<std::size_t>(param.N_),

@@ -97,9 +81,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara
                                  param.input_spatial_lengths_.begin(),
                                  param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
-    else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNWC> || std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNHWC> || std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNDHWC>)
+    else if constexpr(is_any_of<InLayout, GNWC, GNHWC, GNDHWC>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                     static_cast<std::size_t>(param.N_),

@@ -109,9 +91,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara
                                  param.input_spatial_lengths_.begin(),
                                  param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
-    else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NWGC> || std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NHWGC> || std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NDHWGC>)
+    else if constexpr(is_any_of<InLayout, NWGC, NHWGC, NDHWGC>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
                                                     static_cast<std::size_t>(param.G_),

@@ -139,11 +119,11 @@ template <typename WeiLayout>
 CK_TILE_HOST HostTensorDescriptor
 make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvParam& param)
 {
+    using namespace ck_tile::tensor_layout::convolution;
+
     std::vector<std::size_t> physical_lengths;

-    if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXC> || std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXC> || std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXC>)
+    if constexpr(is_any_of<WeiLayout, KXC, KYXC, KZYXC>::value)
     {
         if(param.G_ != 1)
         {

@@ -157,9 +137,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
                                  param.filter_spatial_lengths_.begin(),
                                  param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
-    else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCX> || std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCYX> || std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCZYX>)
+    else if constexpr(is_any_of<WeiLayout, GKCX, GKCYX, GKCZYX>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                     static_cast<std::size_t>(param.K_),

@@ -169,9 +147,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
                                  param.filter_spatial_lengths_.begin(),
                                  param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
-    else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKXC> || std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKYXC> || std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKZYXC>)
+    else if constexpr(is_any_of<WeiLayout, GKXC, GKYXC, GKZYXC>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                     static_cast<std::size_t>(param.K_),

@@ -181,9 +157,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
                                  param.filter_spatial_lengths_.begin(),
                                  param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
-    else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXGC> || std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXGC> || std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXGC>)
+    else if constexpr(is_any_of<WeiLayout, KXGC, KYXGC, KZYXGC>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
                                                     static_cast<std::size_t>(param.G_),

@@ -211,11 +185,11 @@ template <typename OutLayout>
 CK_TILE_HOST HostTensorDescriptor
 make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvParam& param)
 {
+    using namespace ck_tile::tensor_layout::convolution;
+
     std::vector<std::size_t> physical_lengths;

-    if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKW> || std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKHW> || std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
+    if constexpr(is_any_of<OutLayout, GNKW, GNKHW, GNKDHW>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                     static_cast<std::size_t>(param.N_),

@@ -226,9 +200,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar
                                  param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
     // separate from legacy code above
-    else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNWK> || std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNHWK> || std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
+    else if constexpr(is_any_of<OutLayout, GNWK, GNHWK, GNDHWK>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                     static_cast<std::size_t>(param.N_),

@@ -238,9 +210,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar
                                  param.output_spatial_lengths_.begin(),
                                  param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
     }
-    else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NWGK> || std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NHWGK> || std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
+    else if constexpr(is_any_of<OutLayout, NWGK, NHWGK, NDHWGK>::value)
     {
         physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
                                                     static_cast<std::size_t>(param.G_),
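Note: a worked example of what these permutations encode (my reading of the code, not stated in the diff): the returned vector maps each dimension of the canonical GNCHW ordering to its position in the old layout.

    // Illustrative check of the NHWGC branch: canonical order is (G, N, C, H, W),
    // physical order is (N, H, W, G, C); position of each canonical dim in the physical layout:
    //   G -> 3, N -> 0, C -> 4, H -> 1, W -> 2   ==>   {3, 0, 4, 1, 2}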
include/ck_tile/host/host_tensor.hpp  (+34 −1)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -678,4 +678,37 @@ struct HostTensor
     Descriptor mDesc;
     Data mData;
 };
+
+template <typename TLayout>
+auto host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
+{
+    using namespace ck_tile::literals;
+
+    if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
+    {
+        return HostTensorDescriptor({row, col}, {stride, 1_uz});
+    }
+    else
+    {
+        return HostTensorDescriptor({row, col}, {1_uz, stride});
+    }
+}
+
+template <typename TLayout>
+auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
+{
+    if(stride == 0)
+    {
+        if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
+        {
+            return col;
+        }
+        else
+        {
+            return row;
+        }
+    }
+    else
+        return stride;
+}
 } // namespace ck_tile
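Note: a short usage sketch of the two new helpers (sizes are illustrative): `get_default_stride` picks the packed stride when the caller passes 0, and `host_tensor_descriptor` places that stride on the axis matching the layout tag.

    // Illustrative only: build a 64x32 row-major descriptor with a packed stride.
    using Row = ck_tile::tensor_layout::gemm::RowMajor;

    std::size_t stride = ck_tile::get_default_stride(64, 32, 0, Row{});   // -> 32 (one row is 32 elements)
    auto desc          = ck_tile::host_tensor_descriptor(64, 32, stride, Row{});
    // desc describes lengths {64, 32} with strides {32, 1}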
include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp  (+1 −1)

@@ -22,7 +22,7 @@ CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor<XDataType>&
         // scale = amax / 127 for int8
         auto v_scale = type_convert<XDataType>(scale_m(m));
         auto v_qx    = v_x / v_scale;
-        qx_m_n(m, n) = saturates<QXDataType>{}(v_qx);
+        qx_m_n(m, n) = type_convert<QXDataType>(saturates<QXDataType>{}(v_qx));
     }
 };
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp  (+6 −3)

@@ -101,9 +101,12 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
     CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
     {
-        const auto [i_m, i_n] = TilePartitioner{}();
-        const auto i_batch    = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
-        const auto i_k        = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
+        const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
+        const index_t i_m   = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
+        const index_t i_n   = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
+
+        const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
+        const auto i_k     = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);

         const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k);
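Note: the batch/split-K decomposition of `blockIdx.z` above is a plain quotient/remainder split. A worked example with illustrative values:

    // With kargs.KBatch = 4 and blockIdx.z = 9:
    //   i_batch = 9 / 4         = 2   (which batch this block works on)
    //   i_k     = 9 - 2 * 4     = 1   (which K-slice within that batch)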
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp  (+52 −48)

@@ -175,7 +175,7 @@ struct GemmKernel
         if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
         {
-            if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
+            if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
             {
                 std::cerr << "Can't support K that is not a multiple of KPerBlock"
                              " without padding!"

@@ -190,7 +190,7 @@ struct GemmKernel
         else
         {
-            if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
+            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
             {
                 std::cerr << "Can't support M that is not a multiple of MPerBlock"
                              " without padding!"

@@ -206,7 +206,7 @@ struct GemmKernel
         if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
         {
-            if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
+            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
             {
                 std::cerr << "Can't support N that is not a multiple of NPerBlock"
                              " without padding!"

@@ -221,7 +221,7 @@ struct GemmKernel
         else
         {
-            if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
+            if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
             {
                 std::cerr << "Can't support K that is not a multiple of KPerBlock"
                              " without padding!"

@@ -237,7 +237,7 @@ struct GemmKernel
         if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
         {
-            if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
+            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
             {
                 std::cerr << "Can't support N that is not a multiple of NPerBlock"
                              " without padding!"

@@ -252,7 +252,7 @@ struct GemmKernel
         else
         {
-            if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
+            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
             {
                 std::cerr << "Can't support M that is not a multiple of MPerBlock"
                              " without padding!"

@@ -357,17 +357,17 @@ struct GemmKernel
         const auto& a_tensor_view = views.at(I0);
         if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
         {
             return pad_tensor_view(a_tensor_view,
-                make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
+                make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
                 sequence<false, GemmPipeline::kPadK>{});
         }
         else
         {
             return pad_tensor_view(a_tensor_view,
-                make_tuple(number<TilePartitioner::kK>{}, number<TilePartitioner::kM>{}),
+                make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
                 sequence<false, GemmPipeline::kPadM>{});
         }
     }();

@@ -375,17 +375,17 @@ struct GemmKernel
         const auto& b_tensor_view = views.at(I1);
         if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
         {
             return pad_tensor_view(b_tensor_view,
-                make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
+                make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
                 sequence<false, GemmPipeline::kPadK>{});
         }
         else
         {
             return pad_tensor_view(b_tensor_view,
-                make_tuple(number<TilePartitioner::kK>{}, number<TilePartitioner::kN>{}),
+                make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
                 sequence<false, GemmPipeline::kPadN>{});
         }
     }();

@@ -394,17 +394,17 @@ struct GemmKernel
         const auto& c_tensor_view = views.at(I2);
         if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
         {
             return pad_tensor_view(c_tensor_view,
-                make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
+                make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
                 sequence<false, GemmPipeline::kPadN>{});
         }
         else
         {
             return pad_tensor_view(c_tensor_view,
-                make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
+                make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
                 sequence<GemmPipeline::kPadM, false>{});
         }
     }();

@@ -422,40 +422,40 @@ struct GemmKernel
         const auto& a_block_window = [&]() {
             if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
             {
                 return make_tile_window(a_pad_view,
-                    make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
+                    make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
                     {i_m, 0});
             }
             else
             {
                 return make_tile_window(a_pad_view,
-                    make_tuple(number<TilePartitioner::kK>{}, number<TilePartitioner::kM>{}),
+                    make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
                     {0, i_m});
             }
         }();

         const auto& b_block_window = [&]() {
             if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
             {
                 return make_tile_window(b_pad_view,
-                    make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
+                    make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
                     {i_n, 0});
             }
             else
             {
                 return make_tile_window(b_pad_view,
-                    make_tuple(number<TilePartitioner::kK>{}, number<TilePartitioner::kN>{}),
+                    make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
                     {0, i_n});
             }
         }();

         auto c_block_window = make_tile_window(c_pad_view,
-            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
             {i_m, i_n});

         return make_tuple(a_block_window, b_block_window, c_block_window);

@@ -486,6 +486,7 @@ struct GemmKernel
         // Create Gemm tensor views, pad views and tile windows
         const auto& gemm_tensor_views_tuple =
             MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
+
         const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows     = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

@@ -515,7 +516,10 @@ struct GemmKernel
     CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
     {
-        const auto [i_m, i_n] = TilePartitioner{}();
+        const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
+        const index_t i_m   = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
+        const index_t i_n   = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
+
         const SplitKBatchOffset splitk_batch_offset(kargs);

         // options
         const ADataType* a_ptr =
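Note: the partitioner now returns tile coordinates instead of element offsets, so the kernel scales them itself. A worked example with illustrative block sizes: with `MPerBlock = 128`, `NPerBlock = 128` and `GetOutputTileIndex` returning `(iM, iN) = (2, 3)`, the element offsets are `i_m = 2 * 128 = 256` and `i_n = 3 * 128 = 384`. Wrapping the products in `__builtin_amdgcn_readfirstlane` marks them as wavefront-uniform so they can live in scalar registers.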
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp  (+115 −28)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/core.hpp"

 namespace ck_tile {

-template <typename BlockGemmShape_>
-struct GemmTilePartitioner
+/** @brief Struct representing 2D block index mapping into 3D output tile space. */
+template <typename BlockGemmShapeType>
+struct GemmTile2DPartitioner
 {
-    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
+    using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;

-    static constexpr index_t kM = BlockGemmShape::kM;
-    static constexpr index_t kN = BlockGemmShape::kN;
-    static constexpr index_t kK = BlockGemmShape::kK;
+    static constexpr index_t MPerBlock = BlockGemmShape::kM;
+    static constexpr index_t NPerBlock = BlockGemmShape::kN;
+    static constexpr index_t KPerBlock = BlockGemmShape::kK;

-    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size)
+    /** @brief Returns 3D grid size. */
+    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size) noexcept(
+        noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
     {
-        index_t GridDimX = (M + kM - 1) / kM;
-        index_t GridDimY = (N + kN - 1) / kN;
-        index_t GridDimZ = batch_size;
+        const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
+        const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
+        const index_t GridDimZ = batch_size;
         return dim3(GridDimX, GridDimY, GridDimZ);
     }

-    CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K)
+    /**
+     * @brief Returns the number of loops.
+     * @param [in] K is dimension
+     */
+    CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
     {
-        return integer_divide_ceil(K, kK);
+        return integer_divide_ceil(K, KPerBlock);
     }

-    CK_TILE_DEVICE auto operator()()
+    /**
+     * @brief The function returns 2D output tile space.
+     * @param [in] blockIdx is blockIdx.x
+     * @param [in] blockIdy is blockIdx.y
+     * @return Returns the output tile indexes.
+     */
+    CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept
+        -> const tuple<index_t, index_t>
     {
-        const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kM);
-        const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kN);
+        const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
+        const index_t iN = __builtin_amdgcn_readfirstlane(blockIdy);
         return make_tuple(iM, iN);
     }
 };

-template <typename BlockGemmShape_>
+/**
+ * @brief Struct representing 1D block index mapping into 2D output tile space.
+ */
+template <typename BlockGemmShapeType>
 struct GemmTile1DPartitioner
 {
-    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
+    using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;

     static constexpr index_t MPerBlock = BlockGemmShape::kM;
     static constexpr index_t NPerBlock = BlockGemmShape::kN;
     static constexpr index_t KPerBlock = BlockGemmShape::kK;

-    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N)
+    /** @brief delete default ctr with no any object */
+    constexpr GemmTile1DPartitioner() noexcept = delete;
+
+    /** @brief constructs an object that does contain a N value. */
+    constexpr GemmTile1DPartitioner(index_t N) noexcept { N_ = N; }
+
+    /** @brief Returns 1D grid size. */
+    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N) noexcept(
+        noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
     {
-        index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
-        index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
+        const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
+        const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
         return dim3(GridDimX * GridDimY, 1, 1);
     }

-    CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N)
+    /**
+     * @brief Returns the number of blocks in N.
+     * @param [in] N is dimension
+     */
+    CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N) noexcept -> index_t
     {
         return integer_divide_ceil(N, NPerBlock);
     }

-    CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K)
+    /**
+     * @brief Returns the number of loops.
+     * @param [in] K is dimension
+     */
+    CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
     {
         return integer_divide_ceil(K, KPerBlock);
     }

-    CK_TILE_DEVICE auto operator()(index_t blockOffset, index_t NBlockSize)
+    /**
+     * @brief The function returns 2D output tile space.
+     * @param [in] blockIdx is blockIdx.x - block_start.
+     * */
+    CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx) noexcept
+        -> const tuple<index_t, index_t>
     {
-        index_t iM = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) / GetNBlock(NBlockSize) * MPerBlock);
-        index_t iN = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) % GetNBlock(NBlockSize) * NPerBlock);
+        const index_t NBlock = GetNBlock(N_);
+
+        const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlock);
+        const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - (iM)*NBlock);
         return make_tuple(iM, iN);
     }
+
+    private:
+    CK_TILE_DEVICE static index_t N_;
 };
+
+/**
+ * @brief `GemmTile1DPartitioner::GetOutputTileIndex`'s std::false specialization,
+ * checking expression validity in-place for ill-formed.
+ */
+template <typename, typename = void>
+struct HasFnOneArgImpl : std::false_type
+{
+};
+
+/**
+ * @brief `GemmTile1DPartitioner::GetOutputTileIndex`'s std::true specialization,
+ * checking expression validity in-place for well-formed.
+ * @note: `1` - a constant value indicating the number of parameters in the function.
+ */
+template <typename T>
+struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIndex(1))>>
+    : std::true_type
+{
+};
+
+/**
+ * @brief Struct used to calculate offseted tile indexes.
+ * @note: The struct supports the 1D-Partitioner mechanism,
+ * enable-if `GetOutputTileIndex`-fn is std::true_type when `GetOutputTileIndex`-fn is well-formed,
+ * otherwise std::false_type.
+ */
+template <typename PartitionerFn,
+          typename = typename std::enable_if_t<HasFnOneArgImpl<PartitionerFn>{}>>
+struct OffsettedTile1DPartitioner
+{
+    /**
+     * @brief The function subtracts the block's start (offset) from 1D raw-indexes.
+     * @param [in] block_start is `blockIdx.x - block_start`.
+     * @return Returns a `tuple` [Im, In] shifted index, used to shift 1d-tile index.
+     */
+    [[nodiscard]] CK_TILE_DEVICE static constexpr auto GetOffsetedTileIndex(index_t block_start,
+                                                                            index_t N) noexcept
+        -> const tuple<index_t, index_t>
+    {
+        const auto [iM, iN] = PartitionerFn(N).GetOutputTileIndex(blockIdx.x - block_start);
+        return make_tuple(iM, iN);
+    }
+};
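Note: `HasFnOneArgImpl` above is the standard `std::void_t` detection idiom: the partial specialization is only viable when `declval<T>().GetOutputTileIndex(1)` is a well-formed expression. A minimal standalone sketch of the same pattern (names here are illustrative, not from ck_tile):

    #include <type_traits>
    #include <utility>

    // Detect whether T has a GetOutputTileIndex overload callable with one int argument.
    template <typename, typename = void>
    struct has_one_arg_tile_index : std::false_type
    {
    };

    template <typename T>
    struct has_one_arg_tile_index<T, std::void_t<decltype(std::declval<T>().GetOutputTileIndex(1))>>
        : std::true_type
    {
    };

    struct OneArg  { int GetOutputTileIndex(int)      { return 0; } };
    struct TwoArgs { int GetOutputTileIndex(int, int) { return 0; } };

    static_assert(has_one_arg_tile_index<OneArg>::value, "detected: one-argument overload exists");
    static_assert(!has_one_arg_tile_index<TwoArgs>::value, "rejected: call with one argument is ill-formed");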
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp  (+68 −187)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include <iostream>
 #include <string>

 #include "ck_tile/core/numeric/math.hpp"
 #include "ck_tile/core/utility/literals.hpp"
 #include "ck_tile/core/utility/amd_address_space.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
 #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
 #include "ck_tile/host.hpp"

 namespace ck_tile {

-struct GroupedGemmHostArgs
+struct GroupedGemmHostArgs : public ck_tile::GemmHostArgs
 {
-    const void* a_ptr;
-    const void* b_ptr;
-    void* c_ptr;
-    index_t M;
-    index_t N;
-    index_t K;
-    index_t stride_A;
-    index_t stride_B;
-    index_t stride_C;
+    CK_TILE_HOST GroupedGemmHostArgs() noexcept = default;
+    CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
+                                     const void* b_ptr_,
+                                     void* c_ptr_,
+                                     ck_tile::index_t M_,
+                                     ck_tile::index_t N_,
+                                     ck_tile::index_t K_,
+                                     ck_tile::index_t stride_A_,
+                                     ck_tile::index_t stride_B_,
+                                     ck_tile::index_t stride_C_)
+        : GemmHostArgs(a_ptr_, b_ptr_, c_ptr_, KBatch, M_, N_, K_, stride_A_, stride_B_, stride_C_)
+    {
+    }
+
+    private:
+    static constexpr index_t KBatch = 1;
 };

 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
-struct GroupedGemmKernel
+struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
 {
     using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
     using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
     using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
     using ALayout          = remove_cvref_t<typename GemmPipeline::ALayout>;
     using BLayout          = remove_cvref_t<typename GemmPipeline::BLayout>;
     using CLayout          = remove_cvref_t<typename GemmPipeline::CLayout>;
-    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
+    using ADataType        = remove_cvref_t<typename GemmPipeline::ADataType>;
+    using BDataType        = remove_cvref_t<typename GemmPipeline::BDataType>;
+    using CDataType        = remove_cvref_t<typename EpiloguePipeline::ODataType>;
+
+    using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
+    using Base                    = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
+    using GemmKernelArgs          = typename Base::GemmKernelArgs;
+
+    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
+    static constexpr index_t KBatch          = 1;

     struct GemmTransKernelArg
     {
-        GroupedGemmHostArgs group_karg;
+        GemmKernelArgs group_karg;
         ck_tile::index_t block_start;
         ck_tile::index_t block_end;

         GemmTransKernelArg() = default;
-        GemmTransKernelArg(GroupedGemmHostArgs&& karg, index_t bl_start, index_t bl_end)
+        GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
             : group_karg{karg}, block_start{bl_start}, block_end{bl_end}
         {
         }
     };

-    __host__ static size_t GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
+    __host__ static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
+        -> std::size_t
     {
         return gemm_descs.size() * sizeof(GemmTransKernelArg);
     }

-    __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
-
-    using Hargs = GroupedGemmHostArgs;
+    __host__ static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }

-    __host__ static constexpr auto GridSize(const std::vector<Hargs>& gemm_descs)
+    __host__ static constexpr auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
     {
         index_t grid_size = 0;
         for(const auto& it_desc : gemm_descs)

@@ -77,7 +84,8 @@ struct GroupedGemmKernel
         return dim3(grid_size, 1, 1);
     }

-    CK_TILE_HOST static auto MakeKargs(const std::vector<Hargs>& gemm_descs)
+    CK_TILE_HOST static auto MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs)
+        -> std::vector<GemmTransKernelArg>
     {
         std::vector<GemmTransKernelArg> gemm_kernel_args_;
         index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());

@@ -100,22 +108,23 @@ struct GroupedGemmKernel
             const index_t stride_c = gemm_descs[i].stride_C;

             const auto dim3 = TilePartitioner::GridSize(M, N);
-            const index_t grid_size_grp = dim3.x * 1 * 1;
+            const index_t grid_size_grp = dim3.x;

             const index_t block_start = grid_size;
             const index_t block_end   = grid_size + grid_size_grp;

             grid_size += grid_size_grp;

-            auto karg = GroupedGemmHostArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
-                                            type_convert<const BDataType*>(gemm_descs[i].b_ptr),
-                                            type_convert<CDataType*>(gemm_descs[i].c_ptr),
-                                            M, N, K,
-                                            stride_a, stride_b, stride_c};
+            auto karg = GemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
+                                       type_convert<const BDataType*>(gemm_descs[i].b_ptr),
+                                       type_convert<CDataType*>(gemm_descs[i].c_ptr),
+                                       M, N, K,
+                                       stride_a, stride_b, stride_c,
+                                       KBatch};

             gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
         }

@@ -123,162 +132,34 @@ struct GroupedGemmKernel
         return gemm_kernel_args_;
     }

-    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
     {
         return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
     }

-    CK_TILE_DEVICE void Run(const Hargs& kargs, const index_t block_start) const
+    CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs) const
     {
-        const auto [i_m, i_n] = TilePartitioner{}(block_start, kargs.N);
-        // options
-        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
-        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
-        // Convert pointers to tensor views
-        auto a_tensor_view = [&]() {
-            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
-                return make_naive_tensor_view<address_space_enum::global>(a_start, make_tuple(kargs.M, kargs.K), make_tuple(kargs.stride_A, 1), number<GemmPipeline::VectorSizeA>{}, number<1>{});
-            else
-                return make_naive_tensor_view<address_space_enum::global>(a_start, make_tuple(kargs.M, kargs.K), make_tuple(1, kargs.stride_A), number<1>{}, number<1>{});
-        }();
-        auto b_tensor_view = [&]() {
-            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
-                return make_naive_tensor_view<address_space_enum::global>(b_start, make_tuple(kargs.N, kargs.K), make_tuple(1, kargs.stride_B), number<1>{}, number<1>{});
-            else
-                return make_naive_tensor_view<address_space_enum::global>(b_start, make_tuple(kargs.N, kargs.K), make_tuple(kargs.stride_B, 1), number<GemmPipeline::VectorSizeB>{}, number<1>{});
-        }();
-        auto a_pad_view = [&]() {
-            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
-                return pad_tensor_view(a_tensor_view, make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}), sequence<false, GemmPipeline::kPadK>{});
-            else
-                return pad_tensor_view(a_tensor_view, make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}), sequence<GemmPipeline::kPadM, false>{});
-        }();
-        // clang-format on
-        auto a_block_window = make_tile_window(a_pad_view, make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}), {i_m, 0});
-        auto b_pad_view = [&]() {
-            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
-                return pad_tensor_view(b_tensor_view, make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}), sequence<false, GemmPipeline::kPadK>{});
-            else
-                return pad_tensor_view(b_tensor_view, make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}), sequence<GemmPipeline::kPadN, false>{});
-        }();
-        auto b_block_window = make_tile_window(b_pad_view, make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}), {i_n, 0});
-        // allocate LDS
-        __shared__ char smem_ptr[GetSmemSize()];
-        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
-        // Run GEMM cooperatively by whole wokrgroup.
-        auto c_block_tile =
-            GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
-        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
-        auto c_tensor_view = [&]() {
-            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
-                return make_naive_tensor_view<address_space_enum::global>(c_start, make_tuple(kargs.M, kargs.N), make_tuple(kargs.stride_C, 1), number<GemmPipeline::VectorSizeC>{}, number<1>{});
-            else
-                return make_naive_tensor_view<address_space_enum::global>(c_start, make_tuple(kargs.M, kargs.N), make_tuple(1, kargs.stride_C), number<1>{}, number<1>{});
-        }();
-        auto c_pad_view = [&]() {
-            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
-                return pad_tensor_view(c_tensor_view, make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}), sequence<false, GemmPipeline::kPadN>{});
-            else
-                return pad_tensor_view(c_tensor_view, make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}), sequence<GemmPipeline::kPadM, false>{});
-        }();
-        auto CBlockWindow_pad = make_tile_window(c_pad_view, make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}), {i_m, i_n});
-        EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
+        const auto [iM, iN] =
+            OffsetTile1DPartitioner::GetOffsetedTileIndex(kargs.block_start, kargs.group_karg.N);
+        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
+        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
+
+        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs.group_karg, blockIdx.z);
+
+        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.group_karg.a_ptr);
+        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.group_karg.b_ptr);
+        CDataType* c_ptr       = static_cast<CDataType*>(kargs.group_karg.c_ptr);
+
+        // allocate LDS
+        __shared__ char smem_ptr[GetSmemSize()];
+
+        this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs.group_karg, splitk_batch_offset, i_m, i_n);
     }

     CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                   int group_count) const
+                                   index_t group_count) const
     {
         const index_t block_id   = ck_tile::get_block_1d_id();
         const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(

@@ -286,7 +167,7 @@ struct GroupedGemmKernel
         index_t left     = 0;
         index_t right    = group_count;
-        index_t group_id = index_t((left + right) / 2);
+        index_t group_id = index_t((left + right) >> 1);

         while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
                  block_id < gemm_desc_ptr[group_id].block_end)) &&

@@ -300,10 +181,10 @@ struct GroupedGemmKernel
             {
                 left = group_id;
             }
-            group_id = index_t((left + right) / 2);
+            group_id = index_t((left + right) >> 1);
         }

-        Run(gemm_desc_ptr[group_id].group_karg, gemm_desc_ptr[group_id].block_start);
+        Run(gemm_desc_ptr[group_id]);
     }
 };
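Note: the kernel's `operator()` above finds the group that owns the current block by searching the precomputed `[block_start, block_end)` ranges stored in each `GemmTransKernelArg`. A small host-side sketch of an equivalent lookup (the struct name and `std::upper_bound` formulation are illustrative; this is not the kernel's exact loop, and it assumes the ranges are sorted and contiguous, as `MakeKargs` builds them):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    struct BlockRange { int block_start; int block_end; };  // stand-in for GemmTransKernelArg ranges

    // Index of the group whose [block_start, block_end) range contains block_id.
    std::size_t find_group(const std::vector<BlockRange>& groups, int block_id)
    {
        // First group whose block_start is strictly greater than block_id...
        auto it = std::upper_bound(groups.begin(), groups.end(), block_id,
                                   [](int id, const BlockRange& g) { return id < g.block_start; });
        // ...so the previous group is the last one that starts at or before block_id.
        return static_cast<std::size_t>(std::distance(groups.begin(), it)) - 1;
    }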
include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp  (+2 −1)

@@ -101,6 +101,7 @@ struct MoeSmoothquant
 template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
 template <> struct t2s<ck_tile::fp8_t>  { static constexpr const char * name = "fp8"; };
 template <> struct t2s<ck_tile::bf8_t>  { static constexpr const char * name = "bf8"; };
+template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "i8"; };
 // clang-format on

 // in byte

@@ -118,7 +119,7 @@ struct MoeSmoothquant
 #define _SS_ std::string
 #define _TS_ std::to_string
         return _SS_("moe_smoothquant_") + _SS_(t2s<XDataType>::name) + "_" +
                _SS_(t2s<QYDataType>::name) + "_" +
                _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" +
                _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
                _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" +
                _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" + _SS_(Pipeline::name) + surfix;
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp  (+1 −1)

@@ -113,7 +113,7 @@ struct SmoothquantPipelineOnePass
         sweep_tile(qy, [&](auto idx) {
             constexpr auto i_idx = make_tuple(idx[number<0>{}]);
             auto qy_             = y[idx] / yscale[i_idx];

-            qy(idx) = saturates<QYDataType>{}(qy_);
+            qy(idx) = type_convert<QYDataType>(saturates<QYDataType>{}(qy_));
         });

         store_tile(qy_window, qy);
     }
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp  (+1 −1)

@@ -136,7 +136,7 @@ struct SmoothquantPipelineTwoPass
         sweep_tile(qy, [&](auto idx) {
             constexpr auto i_idx = make_tuple(idx[number<0>{}]);
             auto qy_             = y[idx] / yscale[i_idx];

-            qy(idx) = saturates<QYDataType>{}(qy_);
+            qy(idx) = type_convert<QYDataType>(saturates<QYDataType>{}(qy_));
         });

         store_tile(qy_window, qy);
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp  (+4 −4)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include <sstream>

@@ -61,7 +61,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
                              ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                              ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-    using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
+    using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
     using GemmEpilogue    = std::conditional_t<CShuffleEpilogue,

@@ -73,8 +73,8 @@ class TestCkTileBatchedGemm : public ::testing::Test
                                                                        kOutputRank,
                                                                        1,
                                                                        0,
-                                                                       TilePartitioner::kM,
-                                                                       TilePartitioner::kN>>,
+                                                                       TilePartitioner::MPerBlock,
+                                                                       TilePartitioner::NPerBlock>>,
         ck_tile::Default2DEpilogue<
             ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
test/ck_tile/gemm/test_gemm_pipeline_util.hpp  (+1 −1)

@@ -62,7 +62,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
         ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                                ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-    using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
+    using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
     using GemmEpilogue    = ck_tile::Default2DEpilogue<
         ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;