Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f3ff55b6
Commit
f3ff55b6
authored
Dec 17, 2024
by
illsilin
Browse files
merge from public repo
parents
5e93fa9e
689a5ae4
Changes
54
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2774 additions
and
1164 deletions
+2774
-1164
example/ck_tile/12_smoothquant/example_smoothquant.cpp
example/ck_tile/12_smoothquant/example_smoothquant.cpp
+26
-18
example/ck_tile/12_smoothquant/smoothquant.cpp
example/ck_tile/12_smoothquant/smoothquant.cpp
+26
-18
include/ck/config.h.in
include/ck/config.h.in
+16
-0
include/ck/tensor_operation/gpu/device/device_base.hpp
include/ck/tensor_operation/gpu/device/device_base.hpp
+34
-0
include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp
...nsor_operation/gpu/device/device_batched_gemm_multi_d.hpp
+2
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
...e/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
+29
-16
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
...operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
+1
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
+8
-8
include/ck/utility/amd_ck_fp8.hpp
include/ck/utility/amd_ck_fp8.hpp
+18
-6
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+1
-1
include/ck_tile/README.md
include/ck_tile/README.md
+3
-0
include/ck_tile/core.hpp
include/ck_tile/core.hpp
+1
-0
include/ck_tile/ops/flatmm.hpp
include/ck_tile/ops/flatmm.hpp
+1
-0
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp
.../flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp
+510
-0
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc
.../block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc
+794
-589
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc
...ck/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc
+769
-0
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc
...tmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc
+504
-503
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
...sed_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
+27
-2
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
...e/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
+3
-1
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+1
-1
No files found.
example/ck_tile/12_smoothquant/example_smoothquant.cpp
View file @
f3ff55b6
...
...
@@ -35,7 +35,8 @@ auto create_args(int argc, char* argv[])
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"m"
,
"3328"
,
"m dimension"
)
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to n"
)
.
insert
(
"x_stride"
,
"-1"
,
"input stride per row, if -1 then equal to n"
)
.
insert
(
"y_stride"
,
"-1"
,
"output stride per row, if -1 then equal to n"
)
.
insert
(
"e"
,
"1e-5"
,
"epsilon"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"prec"
,
"fp16"
,
"precision"
)
...
...
@@ -49,11 +50,14 @@ auto create_args(int argc, char* argv[])
template
<
typename
DataType
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
if
(
stride
<
0
)
stride
=
n
;
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
x_stride
=
arg_parser
.
get_int
(
"x_stride"
);
if
(
x_stride
<
0
)
x_stride
=
n
;
ck_tile
::
index_t
y_stride
=
arg_parser
.
get_int
(
"y_stride"
);
if
(
y_stride
<
0
)
y_stride
=
n
;
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
...
...
@@ -68,14 +72,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
ComputeDataType
=
float
;
// host verify
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
x_
stride
,
1
});
ck_tile
::
HostTensor
<
XScaleDataType
>
xscale_host
({
n
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_ref
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_dev
({
m
},
{
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_ref
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_ref
({
m
,
n
},
{
y_
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
y_
stride
,
1
});
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XScaleDataType
>
{
1e-3
,
.5
f
}(
xscale_host
);
...
...
@@ -116,7 +120,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf
.
GetDeviceBuffer
(),
m
,
n
,
stride
};
x_stride
,
y_stride
};
auto
kargs
=
Kernel
::
MakeKargs
(
args
);
...
...
@@ -133,7 +138,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
do_validation
)
{
using
YDataType
=
ComputeDataType
;
ck_tile
::
HostTensor
<
ComputeDataType
>
y_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
ComputeDataType
>
y_host
({
m
,
n
},
{
y_
stride
,
1
});
// smooth outlier
{
auto
f
=
[
&
](
auto
n_
)
{
...
...
@@ -183,7 +188,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf
.
FromDevice
(
qy_host_dev
.
data
());
auto
[
rtol
,
atol
]
=
get_elimit
<
QYDataType
>
();
if
(
stride
==
n
)
if
(
y_
stride
==
n
)
{
pass
=
ck_tile
::
check_err
(
qy_host_dev
,
qy_host_ref
,
...
...
@@ -195,10 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
for
(
int
i_r
=
0
;
i_r
<
m
;
i_r
++
)
{
std
::
vector
<
QYDataType
>
qy_host_dev_row
(
qy_host_dev
.
begin
()
+
i_r
*
stride
,
qy_host_dev
.
begin
()
+
i_r
*
stride
+
n
);
std
::
vector
<
QYDataType
>
qy_host_ref_row
(
qy_host_ref
.
begin
()
+
i_r
*
stride
,
qy_host_ref
.
begin
()
+
i_r
*
stride
+
n
);
std
::
vector
<
QYDataType
>
qy_host_dev_row
(
qy_host_dev
.
begin
()
+
i_r
*
y_stride
,
qy_host_dev
.
begin
()
+
i_r
*
y_stride
+
n
);
std
::
vector
<
QYDataType
>
qy_host_ref_row
(
qy_host_ref
.
begin
()
+
i_r
*
y_stride
,
qy_host_ref
.
begin
()
+
i_r
*
y_stride
+
n
);
pass
&=
ck_tile
::
check_err
(
qy_host_dev_row
,
qy_host_ref_row
,
std
::
string
(
"qy["
)
+
std
::
to_string
(
i_r
)
+
...
...
@@ -210,8 +217,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
std
::
cout
<<
"["
<<
data_type
<<
"]"
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", x_stride:"
<<
x_stride
<<
", y_stride:"
<<
y_stride
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
}
return
pass
;
...
...
example/ck_tile/12_smoothquant/smoothquant.cpp
View file @
f3ff55b6
...
...
@@ -33,7 +33,8 @@ auto create_args(int argc, char* argv[])
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"m"
,
"3328"
,
"m dimension"
)
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to n"
)
.
insert
(
"x_stride"
,
"-1"
,
"input stride per row, if -1 then equal to n"
)
.
insert
(
"y_stride"
,
"-1"
,
"output stride per row, if -1 then equal to n"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"prec"
,
"fp16"
,
"precision"
)
...
...
@@ -47,18 +48,21 @@ auto create_args(int argc, char* argv[])
template
<
typename
DataType
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
if
(
stride
<
0
)
stride
=
n
;
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
x_stride
=
arg_parser
.
get_int
(
"x_stride"
);
if
(
x_stride
<
0
)
x_stride
=
n
;
ck_tile
::
index_t
y_stride
=
arg_parser
.
get_int
(
"y_stride"
);
if
(
y_stride
<
0
)
y_stride
=
n
;
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
assert
(
stride
>=
n
);
assert
(
x_
stride
>=
n
);
using
TypeConfig
=
SmoothquantTypeConfig
<
DataType
>
;
...
...
@@ -69,14 +73,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
ComputeDataType
=
typename
TypeConfig
::
ComputeDataType
;
// host verify
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
x_
stride
,
1
});
ck_tile
::
HostTensor
<
XScaleDataType
>
xscale_host
({
n
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_ref
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_dev
({
m
},
{
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_ref
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_ref
({
m
,
n
},
{
y_
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
y_
stride
,
1
});
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XScaleDataType
>
{
1e-3
,
.5
f
}(
xscale_host
);
...
...
@@ -90,7 +94,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
xscale_buf
.
ToDevice
(
xscale_host
.
data
());
std
::
cout
<<
"["
<<
data_type
<<
"]"
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
std
::
flush
;
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", x_stride:"
<<
x_stride
<<
", y_stride:"
<<
y_stride
<<
std
::
flush
;
smoothquant_traits
traits
{
data_type
};
...
...
@@ -100,7 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf
.
GetDeviceBuffer
(),
m
,
n
,
stride
};
x_stride
,
y_stride
};
float
ave_time
=
smoothquant
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
...
...
@@ -116,7 +122,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
do_validation
)
{
using
YDataType
=
ComputeDataType
;
ck_tile
::
HostTensor
<
ComputeDataType
>
y_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
ComputeDataType
>
y_host
({
m
,
n
},
{
y_
stride
,
1
});
// smooth outlier
{
auto
f
=
[
&
](
auto
n_
)
{
...
...
@@ -166,7 +172,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf
.
FromDevice
(
qy_host_dev
.
data
());
auto
[
rtol
,
atol
]
=
get_elimit
<
QYDataType
>
();
if
(
stride
==
n
)
if
(
y_
stride
==
n
)
{
pass
=
ck_tile
::
check_err
(
qy_host_dev
,
qy_host_ref
,
...
...
@@ -178,10 +184,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
for
(
int
i_r
=
0
;
i_r
<
m
;
i_r
++
)
{
std
::
vector
<
QYDataType
>
qy_host_dev_row
(
qy_host_dev
.
begin
()
+
i_r
*
stride
,
qy_host_dev
.
begin
()
+
i_r
*
stride
+
n
);
std
::
vector
<
QYDataType
>
qy_host_ref_row
(
qy_host_ref
.
begin
()
+
i_r
*
stride
,
qy_host_ref
.
begin
()
+
i_r
*
stride
+
n
);
std
::
vector
<
QYDataType
>
qy_host_dev_row
(
qy_host_dev
.
begin
()
+
i_r
*
y_stride
,
qy_host_dev
.
begin
()
+
i_r
*
y_stride
+
n
);
std
::
vector
<
QYDataType
>
qy_host_ref_row
(
qy_host_ref
.
begin
()
+
i_r
*
y_stride
,
qy_host_ref
.
begin
()
+
i_r
*
y_stride
+
n
);
pass
&=
ck_tile
::
check_err
(
qy_host_dev_row
,
qy_host_ref_row
,
std
::
string
(
"qy["
)
+
std
::
to_string
(
i_r
)
+
...
...
include/ck/config.h.in
View file @
f3ff55b6
...
...
@@ -111,6 +111,22 @@
#cmakedefine CK_USE_WMMA @CK_USE_WMMA@
#endif
#ifndef CK_USE_GFX94
#cmakedefine CK_USE_GFX94 @CK_USE_GFX94@
#endif
#ifndef DCK_USE_OCP_FP8
#cmakedefine DCK_USE_OCP_FP8 @DCK_USE_OCP_FP8@
#endif
#ifndef CK_USE_FNUZ_FP8
#cmakedefine CK_USE_FNUZ_FP8 @CK_USE_FNUZ_FP8@
#endif
#ifndef CK_USE_FP8_ON_UNSUPPORTED_ARCH
#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@
#endif
// clang-format on
#endif // CK_CONFIG_H_IN
include/ck/tensor_operation/gpu/device/device_base.hpp
View file @
f3ff55b6
...
...
@@ -5,6 +5,8 @@
#include <string>
#include <sstream>
#include <regex>
#include <optional>
#include "ck/stream_config.hpp"
...
...
@@ -12,6 +14,34 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
#define GET_OBJECT_NAME_IMLP \
std::optional<std::string> GetObjectName() const override \
{ \
std::string str = __PRETTY_FUNCTION__; \
static std::regex obj_name_expr{"<std::string> (.*)::GetObjectName"}; \
std::smatch match; \
if(!std::regex_search(str, match, obj_name_expr)) \
{ \
return str; \
} \
return std::string(match[1]) + ';'; \
}
#define GET_TEMPLATE_INFO_IMPL \
std::optional<std::string> GetTemplateInfo() const override \
{ \
std::string str = __PRETTY_FUNCTION__; \
static std::regex template_expr{"\\[(.*)\\]"}; \
std::smatch match; \
if(!std::regex_search(str, match, template_expr)) \
{ \
return std::nullopt; \
} \
return std::string(match[1]); \
}
#define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL
struct
BaseArgument
{
BaseArgument
()
=
default
;
...
...
@@ -48,6 +78,10 @@ struct BaseOperator
virtual
std
::
string
GetTypeIdName
()
const
{
return
typeid
(
*
this
).
name
();
}
virtual
std
::
optional
<
std
::
string
>
GetObjectName
()
const
{
return
std
::
nullopt
;
}
virtual
std
::
optional
<
std
::
string
>
GetTemplateInfo
()
const
{
return
std
::
nullopt
;
}
virtual
std
::
string
GetTypeIdHashCode
()
const
{
std
::
ostringstream
oss
;
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp
View file @
f3ff55b6
...
...
@@ -89,7 +89,8 @@ struct DeviceBatchedGemmV2MultiD : public BaseOperator
index_t
BatchStrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
=
0
;
CDEElementwiseOperation
cde_element_op
,
index_t
KBatch
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
View file @
f3ff55b6
...
...
@@ -41,12 +41,15 @@ __global__ void
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
g_idx
=
blockIdx
.
z
%
karg
.
Batch
;
const
index_t
k_idx
=
blockIdx
.
z
/
karg
.
Batch
;
const
auto
a_batch_offset
=
karg
.
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
);
const
auto
b_batch_offset
=
karg
.
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
);
const
auto
ds_batch_offset
=
karg
.
compute_ptr_offset_of_batch
.
GetDsPtrOffset
(
g_idx
);
const
auto
c_batch_offset
=
karg
.
compute_ptr_offset_of_batch
.
GetCPtrOffset
(
g_idx
);
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
,
k_idx
);
// populate pointer, desc for Ds
static_for
<
0
,
GridwiseGemm
::
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
// D pointer
...
...
@@ -54,8 +57,8 @@ __global__ void
});
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
+
a_batch_offset
,
karg
.
p_b_grid
+
b_batch_offset
,
karg
.
p_a_grid
+
a_batch_offset
+
splitk_batch_offset
.
a_k_split_offset
,
karg
.
p_b_grid
+
b_batch_offset
+
splitk_batch_offset
.
b_k_split_offset
,
karg
.
p_ds_grid
,
karg
.
p_c_grid
+
c_batch_offset
,
p_shared
,
...
...
@@ -87,12 +90,15 @@ __global__ void
__shared__
char
p_shared_1
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
g_idx
=
blockIdx
.
z
%
karg
.
Batch
;
const
index_t
k_idx
=
blockIdx
.
z
/
karg
.
Batch
;
const
auto
a_batch_offset
=
karg
.
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
);
const
auto
b_batch_offset
=
karg
.
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
);
const
auto
ds_batch_offset
=
karg
.
compute_ptr_offset_of_batch
.
GetDsPtrOffset
(
g_idx
);
const
auto
c_batch_offset
=
karg
.
compute_ptr_offset_of_batch
.
GetCPtrOffset
(
g_idx
);
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
,
k_idx
);
// populate pointer, desc for Ds
static_for
<
0
,
GridwiseGemm
::
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
// D pointer
...
...
@@ -100,8 +106,8 @@ __global__ void
});
GridwiseGemm
::
template
Run_2Lds
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
+
a_batch_offset
,
karg
.
p_b_grid
+
b_batch_offset
,
karg
.
p_a_grid
+
a_batch_offset
+
splitk_batch_offset
.
a_k_split_offset
,
karg
.
p_b_grid
+
b_batch_offset
+
splitk_batch_offset
.
b_k_split_offset
,
karg
.
p_ds_grid
,
karg
.
p_c_grid
+
c_batch_offset
,
p_shared_0
,
...
...
@@ -303,7 +309,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t
Batch_
,
AElementwiseOperation
a_element_op_
,
BElementwiseOperation
b_element_op_
,
CElementwiseOperation
c_element_op_
)
CElementwiseOperation
c_element_op_
,
index_t
KBatch_
)
:
GridwiseGemm
::
Argument
{
p_a_grid_
,
p_b_grid_
,
p_ds_grid_
,
...
...
@@ -315,7 +322,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
StrideB_
,
StrideDs_
,
StrideE_
,
1
,
KBatch_
,
a_element_op_
,
b_element_op_
,
c_element_op_
},
...
...
@@ -336,13 +343,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
arg
.
Print
();
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
)
||
arg
.
KBatch
>
1
)
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
M
,
arg
.
N
,
arg
.
Batch
);
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
M
,
arg
.
N
,
arg
.
Batch
*
arg
.
KBatch
);
float
ave_time
=
0
;
...
...
@@ -387,10 +395,11 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
rotating_mem
.
Next
();
// clear c mem
if
(
arg_
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
arg_
.
p_c_grid
,
0
,
arg_
.
M
*
arg_
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
hipGetErrorString
(
hipMemsetAsync
(
arg_
.
p_c_grid
,
0
,
arg
.
Batch
*
arg_
.
M
*
arg_
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
};
ave_time
=
ck
::
utility
::
launch_and_time_kernel_with_preprocess
<
false
>
(
...
...
@@ -889,7 +898,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t
BatchStrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
,
index_t
KBatch
=
1
)
{
return
Argument
{
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
...
...
@@ -909,7 +919,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
Batch
,
a_element_op
,
b_element_op
,
c_element_op
};
c_element_op
,
KBatch
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -934,7 +945,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t
BatchStrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
override
CElementwiseOperation
c_element_op
,
index_t
KBatch
=
1
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
...
...
@@ -954,7 +966,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
Batch
,
a_element_op
,
b_element_op
,
c_element_op
);
c_element_op
,
KBatch
);
}
// polymorphic
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
View file @
f3ff55b6
...
...
@@ -729,6 +729,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return
str
.
str
();
}
REGISTER_EXTRA_PRINTING_METHODS
};
}
// namespace device
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
View file @
f3ff55b6
...
...
@@ -41,7 +41,7 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
);
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
,
blockIdx
.
z
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
+
splitk_batch_offset
.
a_k_split_offset
,
...
...
@@ -76,7 +76,7 @@ __global__ void
__shared__
char
p_shared_0
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared_1
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
);
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
,
blockIdx
.
z
);
GridwiseGemm
::
template
Run_2Lds
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
+
splitk_batch_offset
.
a_k_split_offset
,
...
...
@@ -639,27 +639,27 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
struct
SplitKBatchOffset
{
__device__
SplitKBatchOffset
(
Argument
&
karg
)
__device__
SplitKBatchOffset
(
Argument
&
karg
,
index_t
k_id
)
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
a_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
;
a_k_split_offset
=
k_id
*
karg
.
KRead
;
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
a_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
StrideA
;
a_k_split_offset
=
k_id
*
karg
.
KRead
*
karg
.
StrideA
;
}
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
)
{
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
StrideB
;
b_k_split_offset
=
k_id
*
karg
.
KRead
*
karg
.
StrideB
;
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
;
b_k_split_offset
=
k_id
*
karg
.
KRead
;
}
if
(
blockIdx
.
z
<
static_cast
<
uint32_t
>
(
karg
.
KBatch
-
1
)
)
if
(
k_id
<
karg
.
KBatch
-
1
)
{
karg
.
K
=
karg
.
KRead
;
}
...
...
include/ck/utility/amd_ck_fp8.hpp
View file @
f3ff55b6
...
...
@@ -18,6 +18,20 @@
#define CK_USE_OCP_FP8 0
#endif
namespace
{
// https://en.cppreference.com/w/cpp/types/conditional
template
<
bool
B
,
class
T
,
class
F
>
struct
conditional
{
using
type
=
T
;
};
template
<
class
T
,
class
F
>
struct
conditional
<
false
,
T
,
F
>
{
using
type
=
F
;
};
}
// namespace
namespace
ck
{
using
f8_fnuz_t
=
_BitInt
(
8
);
...
...
@@ -191,11 +205,10 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
}
}
typename
__hip_internal
::
conditional
<
typename
conditional
<
sizeof
(
T
)
==
2
,
unsigned
short
int
,
typename
__hip_internal
::
conditional
<
sizeof
(
T
)
==
4
,
unsigned
int
,
unsigned
long
long
>::
type
>::
type
retval
;
typename
conditional
<
sizeof
(
T
)
==
4
,
unsigned
int
,
unsigned
long
long
>::
type
>::
type
retval
;
if
constexpr
(
we
==
5
&&
is_half
&&
!
is_fnuz
)
{
...
...
@@ -538,11 +551,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
constexpr
int
mfmt
=
(
sizeof
(
T
)
==
8
)
?
52
:
((
sizeof
(
T
)
==
4
)
?
23
:
10
);
using
T_bitwise
=
typename
__hip_internal
::
conditional
<
using
T_bitwise
=
typename
conditional
<
sizeof
(
T
)
==
2
,
unsigned
short
int
,
typename
__hip_internal
::
conditional
<
sizeof
(
T
)
==
4
,
unsigned
int
,
unsigned
long
long
>::
type
>::
type
;
typename
conditional
<
sizeof
(
T
)
==
4
,
unsigned
int
,
unsigned
long
long
>::
type
>::
type
;
T_bitwise
x_bitwise
=
bit_cast
<
T_bitwise
>
(
_x
);
unsigned
long
long
x
{
x_bitwise
};
...
...
include/ck/utility/math_v2.hpp
View file @
f3ff55b6
...
...
@@ -611,7 +611,7 @@ inline __device__ int8_t neg<int8_t>(int8_t x)
template
<
>
inline
__device__
half_t
neg
<
half_t
>
(
half_t
x
)
{
return
__hneg
(
x
);
return
__hneg
(
static_cast
<
__half
>
(
x
)
);
};
template
<
typename
T
>
...
...
include/ck_tile/README.md
View file @
f3ff55b6
...
...
@@ -45,5 +45,8 @@ our implementation of different device operators.
**[ops/epilogue]**
epilogue part of our kernel. We may extend this epilogue part to let users to build their own cutomized epilogues.
**[ref]**
reference implementation of cpu or gpu. This folder is supposed to include a specific header on demand.
## examples
currently we put all ck_tile related example under
[
/example/ck_tile
](
/example/ck_tile/
)
folder. Please check each example's subfolder.
include/ck_tile/core.hpp
View file @
f3ff55b6
...
...
@@ -54,6 +54,7 @@
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
...
...
include/ck_tile/ops/flatmm.hpp
View file @
f3ff55b6
...
...
@@ -5,6 +5,7 @@
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp
0 → 100644
View file @
f3ff55b6
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc
View file @
f3ff55b6
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc
0 → 100644
View file @
f3ff55b6
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc
View file @
f3ff55b6
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
View file @
f3ff55b6
...
...
@@ -810,21 +810,46 @@ struct FusedMoeGemmPipelineFlatmmPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetUK_1
()
{
using
S_
=
typename
Problem
::
BlockShape
;
using
T_
=
typename
Problem
::
Traits
;
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
TopkWeightDataType
,
float
>
&&
S_
::
Block_M1
==
32
&&
S_
::
Block_N1
==
128
&&
S_
::
Block_K1
==
512
&&
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
)
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
&&
T_
::
PipeInterleave
==
false
)
{
return
FlatmmSn_32x128x512_1x4x1_16x16x32_BF16
{};
// return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
fp16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
fp16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
TopkWeightDataType
,
float
>
&&
S_
::
Block_M1
==
32
&&
S_
::
Block_N1
==
128
&&
S_
::
Block_K1
==
512
&&
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
)
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
&&
T_
::
PipeInterleave
==
false
)
{
return
FlatmmSn_32x128x512_1x4x1_16x16x32_FP16
{};
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
TopkWeightDataType
,
float
>
&&
S_
::
Block_M1
==
32
&&
S_
::
Block_N1
==
128
&&
S_
::
Block_K1
==
512
&&
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
&&
T_
::
PipeInterleave
==
true
)
{
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return
FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
fp16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
fp16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
TopkWeightDataType
,
float
>
&&
S_
::
Block_M1
==
32
&&
S_
::
Block_N1
==
128
&&
S_
::
Block_K1
==
512
&&
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
&&
T_
::
PipeInterleave
==
true
)
{
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return
FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl
{};
}
}
};
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
View file @
f3ff55b6
...
...
@@ -22,7 +22,8 @@ template <bool IsGateOnly_,
FusedMoeGemmWeightPermuteEnum
PermuteEnum_
=
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
,
bool
PadHiddenSize_
=
false
,
bool
PadIntermediateSize_
=
false
>
bool
PadIntermediateSize_
=
false
,
bool
PipeInterleave_
=
true
>
struct
FusedMoeGemmTraits
{
// Gate+Up or Gate only
...
...
@@ -32,6 +33,7 @@ struct FusedMoeGemmTraits
static
constexpr
FusedMoeGemmWeightPermuteEnum
PermuteEnum
=
PermuteEnum_
;
static
constexpr
bool
PadHiddenSize
=
PadHiddenSize_
;
static
constexpr
bool
PadIntermediateSize
=
PadIntermediateSize_
;
static
constexpr
bool
PipeInterleave
=
PipeInterleave_
;
};
// Note: this need to be a bit mask
...
...
include/ck_tile/ops/gemm.hpp
View file @
f3ff55b6
...
...
@@ -23,10 +23,10 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment