Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d20c20a6
Commit
d20c20a6
authored
Oct 16, 2024
by
Mirza Halilcevic
Browse files
Merge remote-tracking branch 'upstream/develop' into gemm_elementwise_gemm
parents
250a89f3
10158b0f
Changes
95
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
167 additions
and
62 deletions
+167
-62
example/ck_tile/01_fmha/fmha_bwd.cpp
example/ck_tile/01_fmha/fmha_bwd.cpp
+22
-1
example/ck_tile/01_fmha/fmha_bwd.hpp
example/ck_tile/01_fmha/fmha_bwd.hpp
+5
-1
example/ck_tile/01_fmha/fmha_fwd.cpp
example/ck_tile/01_fmha/fmha_fwd.cpp
+20
-3
example/ck_tile/01_fmha/fmha_fwd.hpp
example/ck_tile/01_fmha/fmha_fwd.hpp
+5
-1
example/ck_tile/02_layernorm2d/README.md
example/ck_tile/02_layernorm2d/README.md
+2
-1
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
+3
-1
example/ck_tile/03_gemm/README.md
example/ck_tile/03_gemm/README.md
+14
-6
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+44
-16
example/ck_tile/04_img2col/README.md
example/ck_tile/04_img2col/README.md
+2
-1
include/ck/config.h.in
include/ck/config.h.in
+0
-7
include/ck/host_utility/kernel_launch.hpp
include/ck/host_utility/kernel_launch.hpp
+6
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
+5
-4
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
+10
-8
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
+5
-4
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
+3
-3
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
...ation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
+17
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
...evice/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
...grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
...device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
...tion/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+1
-1
No files found.
example/ck_tile/01_fmha/fmha_bwd.cpp
View file @
d20c20a6
...
@@ -85,6 +85,9 @@ auto create_args(int argc, char* argv[])
...
@@ -85,6 +85,9 @@ auto create_args(int argc, char* argv[])
.
insert
(
"p_drop"
,
"0"
,
"0~1 probability of dropout"
)
.
insert
(
"p_drop"
,
"0"
,
"0~1 probability of dropout"
)
.
insert
(
"drop_seed"
,
"1"
,
"seed for random number generator"
)
.
insert
(
"drop_seed"
,
"1"
,
"seed for random number generator"
)
.
insert
(
"drop_offset"
,
"0"
,
"offset for random number generator"
)
.
insert
(
"drop_offset"
,
"0"
,
"offset for random number generator"
)
.
insert
(
"drop_prefs"
,
"0"
,
"seed and offset values are present on GPU; 0 - host, 1 - device/GPU"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
.
insert
(
"warmup"
,
"5"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"warmup"
,
"5"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"repeat"
,
"20"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"repeat"
,
"20"
,
"number of iterations to benchmark the kernel"
)
...
@@ -158,6 +161,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -158,6 +161,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
float
p_drop
=
arg_parser
.
get_float
(
"p_drop"
);
float
p_drop
=
arg_parser
.
get_float
(
"p_drop"
);
uint64_t
drop_seed
=
arg_parser
.
get_uint64
(
"drop_seed"
);
uint64_t
drop_seed
=
arg_parser
.
get_uint64
(
"drop_seed"
);
uint64_t
drop_offset
=
arg_parser
.
get_uint64
(
"drop_offset"
);
uint64_t
drop_offset
=
arg_parser
.
get_uint64
(
"drop_offset"
);
bool
drop_prefs
=
arg_parser
.
get_bool
(
"drop_prefs"
);
if
(
use_dbias
&&
bias
.
type
!=
bias_enum
::
elementwise_bias
)
if
(
use_dbias
&&
bias
.
type
!=
bias_enum
::
elementwise_bias
)
{
{
std
::
cerr
<<
"dbias only exists when bias type is elementwise"
<<
std
::
endl
;
std
::
cerr
<<
"dbias only exists when bias type is elementwise"
<<
std
::
endl
;
...
@@ -381,6 +386,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -381,6 +386,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
DeviceMem
dbias_buf
(
dbias_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
dbias_buf
(
dbias_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
seqstart_q
(
seqstart_q_host
.
size
()
*
sizeof
(
int32_t
));
ck_tile
::
DeviceMem
seqstart_q
(
seqstart_q_host
.
size
()
*
sizeof
(
int32_t
));
ck_tile
::
DeviceMem
seqstart_k
(
seqstart_k_host
.
size
()
*
sizeof
(
int32_t
));
ck_tile
::
DeviceMem
seqstart_k
(
seqstart_k_host
.
size
()
*
sizeof
(
int32_t
));
ck_tile
::
DeviceMem
drop_seed_buf
(
drop_prefs
?
sizeof
(
uint64_t
)
:
0
);
ck_tile
::
DeviceMem
drop_offset_buf
(
drop_prefs
?
sizeof
(
uint64_t
)
:
0
);
ck_tile
::
DeviceMem
alibi_slope_buf
(
alibi_slope_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
alibi_slope_buf
(
alibi_slope_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
dq_acc_buf
(
dq_acc_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
dq_acc_buf
(
dq_acc_host
.
get_element_space_size_in_bytes
());
...
@@ -391,6 +398,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -391,6 +398,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
do_buf
.
ToDevice
(
do_host
.
data
());
do_buf
.
ToDevice
(
do_host
.
data
());
seqstart_q
.
ToDevice
(
seqstart_q_host
.
data
());
seqstart_q
.
ToDevice
(
seqstart_q_host
.
data
());
seqstart_k
.
ToDevice
(
seqstart_k_host
.
data
());
seqstart_k
.
ToDevice
(
seqstart_k_host
.
data
());
drop_seed_buf
.
ToDevice
(
drop_prefs
?
&
drop_seed
:
nullptr
);
drop_offset_buf
.
ToDevice
(
drop_prefs
?
&
drop_offset
:
nullptr
);
alibi_slope_buf
.
ToDevice
(
alibi_slope_host
.
data
());
alibi_slope_buf
.
ToDevice
(
alibi_slope_host
.
data
());
// clang-format off
// clang-format off
...
@@ -472,6 +481,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -472,6 +481,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
split_stride_dq_acc
=
const
ck_tile
::
index_t
split_stride_dq_acc
=
(
shape_batch
*
nhead
*
shape_seqlen_q
*
hdim_q
);
(
shape_batch
*
nhead
*
shape_seqlen_q
*
hdim_q
);
const
auto
drop_seed_offset
=
[
&
]()
->
decltype
(
fmha_bwd_args
::
drop_seed_offset
)
{
if
(
drop_prefs
)
{
return
std
::
make_pair
(
drop_seed_buf
.
GetDeviceBuffer
(),
drop_offset_buf
.
GetDeviceBuffer
());
}
else
{
return
std
::
make_pair
(
drop_seed
,
drop_offset
);
}
}();
return
fmha_bwd_args
{
q_buf
.
GetDeviceBuffer
(),
return
fmha_bwd_args
{
q_buf
.
GetDeviceBuffer
(),
k_buf
.
GetDeviceBuffer
(),
k_buf
.
GetDeviceBuffer
(),
v_buf
.
GetDeviceBuffer
(),
v_buf
.
GetDeviceBuffer
(),
...
@@ -545,7 +566,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -545,7 +566,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
static_cast
<
ck_tile
::
index_t
>
(
mask
.
type
),
static_cast
<
ck_tile
::
index_t
>
(
mask
.
type
),
p_drop
,
p_drop
,
p_undrop
,
p_undrop
,
{
drop_seed
,
drop
_offset
}
}
;
drop_seed_offset
};
}();
}();
float
ave_time
=
fmha_bwd
(
fmha_traits
,
fmha_args
,
stream_config
);
float
ave_time
=
fmha_bwd
(
fmha_traits
,
fmha_args
,
stream_config
);
...
...
example/ck_tile/01_fmha/fmha_bwd.hpp
View file @
d20c20a6
...
@@ -9,7 +9,10 @@
...
@@ -9,7 +9,10 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include "mask.hpp"
#include "bias.hpp"
#include "bias.hpp"
#include <type_traits>
#include <type_traits>
#include <utility>
#include <variant>
template
<
typename
DataType
>
template
<
typename
DataType
>
struct
FmhaBwdTypeConfig
;
struct
FmhaBwdTypeConfig
;
...
@@ -135,7 +138,8 @@ struct fmha_bwd_args
...
@@ -135,7 +138,8 @@ struct fmha_bwd_args
ck_tile
::
index_t
mask_type
;
ck_tile
::
index_t
mask_type
;
float
p_drop
;
float
p_drop
;
float
p_undrop
;
float
p_undrop
;
std
::
tuple
<
uint64_t
,
uint64_t
>
drop_seed_offset
;
std
::
variant
<
std
::
pair
<
uint64_t
,
uint64_t
>
,
std
::
pair
<
const
void
*
,
const
void
*>>
drop_seed_offset
;
};
};
template
<
typename
FmhaBwdDQDKDVKernel
>
template
<
typename
FmhaBwdDQDKDVKernel
>
...
...
example/ck_tile/01_fmha/fmha_fwd.cpp
View file @
d20c20a6
...
@@ -122,6 +122,9 @@ auto create_args(int argc, char* argv[])
...
@@ -122,6 +122,9 @@ auto create_args(int argc, char* argv[])
.
insert
(
"p_drop"
,
"0"
,
"0~1 probability of dropout"
)
.
insert
(
"p_drop"
,
"0"
,
"0~1 probability of dropout"
)
.
insert
(
"drop_seed"
,
"1"
,
"seed for random number generator"
)
.
insert
(
"drop_seed"
,
"1"
,
"seed for random number generator"
)
.
insert
(
"drop_offset"
,
"0"
,
"offset for random number generator"
)
.
insert
(
"drop_offset"
,
"0"
,
"offset for random number generator"
)
.
insert
(
"drop_prefs"
,
"0"
,
"seed and offset values are present on GPU; 0 - host, 1 - device/GPU"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
.
insert
(
.
insert
(
"rotary_dim"
,
"0"
,
"RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all"
)
"rotary_dim"
,
"0"
,
"RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all"
)
...
@@ -442,6 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -442,6 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
float
p_drop
=
arg_parser
.
get_float
(
"p_drop"
);
float
p_drop
=
arg_parser
.
get_float
(
"p_drop"
);
uint64_t
drop_seed
=
arg_parser
.
get_uint64
(
"drop_seed"
);
uint64_t
drop_seed
=
arg_parser
.
get_uint64
(
"drop_seed"
);
uint64_t
drop_offset
=
arg_parser
.
get_uint64
(
"drop_offset"
);
uint64_t
drop_offset
=
arg_parser
.
get_uint64
(
"drop_offset"
);
bool
drop_prefs
=
arg_parser
.
get_bool
(
"drop_prefs"
);
if
(
p_drop
<
0.0
f
||
p_drop
>
1.0
f
)
if
(
p_drop
<
0.0
f
||
p_drop
>
1.0
f
)
{
{
std
::
cerr
<<
"The value of p_drop should be 0~1"
<<
std
::
endl
;
std
::
cerr
<<
"The value of p_drop should be 0~1"
<<
std
::
endl
;
...
@@ -756,6 +761,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -756,6 +761,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
need_append_kvcache
?
cache_seqlen_ks
.
size
()
*
sizeof
(
int32_t
)
:
0
);
need_append_kvcache
?
cache_seqlen_ks
.
size
()
*
sizeof
(
int32_t
)
:
0
);
ck_tile
::
DeviceMem
rotary_cos_buf
(
rotary_cos_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
rotary_cos_buf
(
rotary_cos_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
rotary_sin_buf
(
rotary_sin_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
rotary_sin_buf
(
rotary_sin_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
drop_seed_buf
(
drop_prefs
?
sizeof
(
uint64_t
)
:
0
);
ck_tile
::
DeviceMem
drop_offset_buf
(
drop_prefs
?
sizeof
(
uint64_t
)
:
0
);
ck_tile
::
DeviceMem
randval_buf
(
randval_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
randval_buf
(
randval_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
alibi_slope_buf
(
alibi_slope_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
alibi_slope_buf
(
alibi_slope_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
block_table_buf
(
block_table_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
block_table_buf
(
block_table_host
.
get_element_space_size_in_bytes
());
...
@@ -774,6 +781,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -774,6 +781,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
cache_seqlen_k_buf
.
ToDevice
(
need_append_kvcache
?
cache_seqlen_ks
.
data
()
:
nullptr
);
cache_seqlen_k_buf
.
ToDevice
(
need_append_kvcache
?
cache_seqlen_ks
.
data
()
:
nullptr
);
rotary_cos_buf
.
ToDevice
(
rotary_cos_host
.
data
());
rotary_cos_buf
.
ToDevice
(
rotary_cos_host
.
data
());
rotary_sin_buf
.
ToDevice
(
rotary_sin_host
.
data
());
rotary_sin_buf
.
ToDevice
(
rotary_sin_host
.
data
());
drop_seed_buf
.
ToDevice
(
drop_prefs
?
&
drop_seed
:
nullptr
);
drop_offset_buf
.
ToDevice
(
drop_prefs
?
&
drop_offset
:
nullptr
);
alibi_slope_buf
.
ToDevice
(
alibi_slope_host
.
data
());
alibi_slope_buf
.
ToDevice
(
alibi_slope_host
.
data
());
block_table_buf
.
ToDevice
(
block_table_host
.
data
());
block_table_buf
.
ToDevice
(
block_table_host
.
data
());
cache_batch_idx_buf
.
ToDevice
(
cache_batch_idx_host
.
data
());
cache_batch_idx_buf
.
ToDevice
(
cache_batch_idx_host
.
data
());
...
@@ -1015,7 +1024,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -1015,7 +1024,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
args
.
p_drop
=
p_drop
;
args
.
p_drop
=
p_drop
;
args
.
s_randval
=
s_randval
;
args
.
s_randval
=
s_randval
;
args
.
drop_seed_offset
=
std
::
tie
(
drop_seed
,
drop_offset
);
if
(
drop_prefs
)
{
args
.
drop_seed_offset
=
std
::
make_pair
(
drop_seed_buf
.
GetDeviceBuffer
(),
drop_offset_buf
.
GetDeviceBuffer
());
}
else
{
args
.
drop_seed_offset
=
std
::
make_pair
(
drop_seed
,
drop_offset
);
}
}
}
else
if
constexpr
(
std
::
is_same_v
<
fmha_fwd_splitkv_args
,
std
::
decay_t
<
decltype
(
args
)
>>
)
else
if
constexpr
(
std
::
is_same_v
<
fmha_fwd_splitkv_args
,
std
::
decay_t
<
decltype
(
args
)
>>
)
{
{
...
...
example/ck_tile/01_fmha/fmha_fwd.hpp
View file @
d20c20a6
...
@@ -13,6 +13,8 @@
...
@@ -13,6 +13,8 @@
#include "rotary.hpp"
#include "rotary.hpp"
#include <type_traits>
#include <type_traits>
#include <utility>
#include <variant>
template
<
typename
DataType
>
template
<
typename
DataType
>
struct
FmhaFwdTypeConfig
;
struct
FmhaFwdTypeConfig
;
...
@@ -144,7 +146,9 @@ struct fmha_fwd_args
...
@@ -144,7 +146,9 @@ struct fmha_fwd_args
float
p_drop
;
float
p_drop
;
bool
s_randval
;
bool
s_randval
;
std
::
tuple
<
uint64_t
,
uint64_t
>
drop_seed_offset
;
std
::
variant
<
std
::
pair
<
uint64_t
,
uint64_t
>
,
std
::
pair
<
const
void
*
,
const
void
*>>
drop_seed_offset
;
};
};
struct
fmha_fwd_splitkv_args
struct
fmha_fwd_splitkv_args
...
...
example/ck_tile/02_layernorm2d/README.md
View file @
d20c20a6
...
@@ -6,7 +6,8 @@ This folder contains example for Layernorm2D forward using ck_tile tile-programm
...
@@ -6,7 +6,8 @@ This folder contains example for Layernorm2D forward using ck_tile tile-programm
```
```
# in the root of ck_tile
# in the root of ck_tile
mkdir build && cd build
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_layernorm2d_fwd -j
make tile_example_layernorm2d_fwd -j
```
```
This will result in an executable
`build/bin/tile_example_layernorm2d_fwd`
This will result in an executable
`build/bin/tile_example_layernorm2d_fwd`
...
...
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
View file @
d20c20a6
...
@@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
...
@@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
YDataType
,
YDataType
,
MeanDataType
,
MeanDataType
,
InvStdDataType
,
InvStdDataType
,
Shape
>
;
Shape
,
true
,
true
>
;
using
Kernel
=
ck_tile
::
Layernorm2dFwd
<
PipelineProblem
>
;
using
Kernel
=
ck_tile
::
Layernorm2dFwd
<
PipelineProblem
>
;
...
...
example/ck_tile/03_gemm/README.md
View file @
d20c20a6
...
@@ -6,7 +6,8 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
...
@@ -6,7 +6,8 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
```
```
# in the root of ck_tile
# in the root of ck_tile
mkdir build && cd build
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_gemm_basic -j
make tile_example_gemm_basic -j
```
```
This will result in an executable
`build/bin/tile_example_gemm_basic`
This will result in an executable
`build/bin/tile_example_gemm_basic`
...
@@ -14,10 +15,17 @@ This will result in an executable `build/bin/tile_example_gemm_basic`
...
@@ -14,10 +15,17 @@ This will result in an executable `build/bin/tile_example_gemm_basic`
## example
## example
```
```
args:
args:
-m m dimension (default:3328)
-b batch size (default:1)
-n m dimension (default:4096)
-m m dimension (default:1024)
-n n dimension (default:2048)
-k k dimension (default:64)
-k k dimension (default:64)
-e epsilon (default:1e-5)
-stride_a Tensor A stride (default:0)
-v cpu validation or not (default:1)
-stride_b Tensor B stride (default:0)
-prec precision (default:fp16)
-stride_c Tensor C stride (default:0)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
```
```
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
d20c20a6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -43,16 +42,37 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -43,16 +42,37 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kTilePermute
=
false
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
>>
;
// The rank and permutation will also be generated by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
// Whether to perform the CShuffle (transpose before writing to global memory); this depends on
// the output layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
LayoutC
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
kM
,
TilePartitioner
::
kN
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
>>>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
,
LayoutA
,
LayoutB
,
LayoutC
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
args
.
p_b
,
args
.
p_b
,
...
@@ -255,15 +275,17 @@ int main(int argc, char* argv[])
...
@@ -255,15 +275,17 @@ int main(int argc, char* argv[])
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
CodegenPipelineProblem
=
ck_tile
::
BlockGemmPipelineProblem
<
ADataType
,
using
CodegenGemmTraits
=
ck_tile
::
BDataType
,
TileGemmTraits
<
kPadA
,
kPadB
,
kPadC
,
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
;
AccDataType
,
CodegenGemmShape
,
using
CodegenPipelineProblem
=
ck_tile
::
kPadA
,
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
kPadB
,
kPadC
>
;
using
CodegenGemmPipeline
=
ck_tile
::
BlockGemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
CodegenGemmPolicy
=
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
<
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
,
CodegenGemmPolicy
>
;
invoke_gemm
<
ck_tile
::
half_t
,
invoke_gemm
<
ck_tile
::
half_t
,
matrix_a_layout
,
matrix_a_layout
,
...
@@ -341,7 +363,13 @@ int main(int argc, char* argv[])
...
@@ -341,7 +363,13 @@ int main(int argc, char* argv[])
ck_tile
::
HostTensor
<
CDataType
>
c_host_gpu_ref
(
c_dimensions
);
ck_tile
::
HostTensor
<
CDataType
>
c_host_gpu_ref
(
c_dimensions
);
ck_tile
::
DeviceMem
c_gpu_buf
(
c_host_gpu_ref
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
c_gpu_buf
(
c_host_gpu_ref
.
get_element_space_size_in_bytes
());
ck_tile
::
reference_gemm_gpu
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
ck_tile
::
reference_gemm_gpu
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
(
a_buf
,
b_buf
,
c_gpu_buf
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
a_buf
,
b_buf
,
c_gpu_buf
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
c_buf
.
FromDevice
(
c_host_gpu_ref
.
data
());
c_buf
.
FromDevice
(
c_host_gpu_ref
.
data
());
...
...
example/ck_tile/04_img2col/README.md
View file @
d20c20a6
...
@@ -6,7 +6,8 @@ This folder contains example for Image to Column using ck_tile tile-programming
...
@@ -6,7 +6,8 @@ This folder contains example for Image to Column using ck_tile tile-programming
```
```
# in the root of ck_tile
# in the root of ck_tile
mkdir build && cd build
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_img2col -j
make tile_example_img2col -j
```
```
This will result in an executable
`build/bin/tile_example_img2col`
This will result in an executable
`build/bin/tile_example_img2col`
include/ck/config.h.in
View file @
d20c20a6
...
@@ -97,13 +97,6 @@
...
@@ -97,13 +97,6 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif
#endif
//
// Instances supports in the current CK build
//
#ifndef CK_ENABLE_INSTANCES_ONLY
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
//
//
// CK kernels which support XDL (MI series)
// CK kernels which support XDL (MI series)
//
//
...
...
include/ck/host_utility/kernel_launch.hpp
View file @
d20c20a6
...
@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
...
@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
hip_check_error
(
hipEventElapsedTime
(
&
total_time
,
start
,
stop
));
hip_check_error
(
hipEventElapsedTime
(
&
total_time
,
start
,
stop
));
hip_check_error
(
hipEventDestroy
(
start
));
hip_check_error
(
hipEventDestroy
(
stop
));
return
total_time
/
nrepeat
;
return
total_time
/
nrepeat
;
}
}
else
else
...
@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
...
@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
hip_check_error
(
hipEventElapsedTime
(
&
total_time
,
start
,
stop
));
hip_check_error
(
hipEventElapsedTime
(
&
total_time
,
start
,
stop
));
hip_check_error
(
hipEventDestroy
(
start
));
hip_check_error
(
hipEventDestroy
(
stop
));
return
total_time
/
nrepeat
;
return
total_time
/
nrepeat
;
}
}
else
else
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
View file @
d20c20a6
...
@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
@@ -390,7 +390,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -390,7 +390,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
});
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
View file @
d20c20a6
...
@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
@@ -518,7 +518,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -518,7 +518,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
});
...
@@ -575,7 +576,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -575,7 +576,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
});
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
View file @
d20c20a6
...
@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
@@ -504,7 +504,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -504,7 +504,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
});
...
...
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
View file @
d20c20a6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "device_base.hpp"
#include "device_base.hpp"
...
@@ -37,7 +37,7 @@ struct DeviceCGemm : public BaseOperator
...
@@ -37,7 +37,7 @@ struct DeviceCGemm : public BaseOperator
index_t
KRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
)
=
0
;
index_t
StrideC
)
const
=
0
;
};
};
template
<
typename
AElementwiseOperation
,
template
<
typename
AElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
d20c20a6
...
@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
[[
maybe_unused
]]
index_t
K
,
[[
maybe_unused
]]
index_t
K
,
[[
maybe_unused
]]
index_t
StrideA
,
[[
maybe_unused
]]
index_t
StrideA
,
[[
maybe_unused
]]
index_t
StrideB
,
[[
maybe_unused
]]
index_t
StrideB
,
index_t
StrideC
)
override
index_t
StrideC
)
const
override
{
{
return
2
*
sizeof
(
CDataType
)
*
GetCElementSpaceSize
(
M
,
N
,
StrideC
);
return
2
*
sizeof
(
CDataType
)
*
GetCElementSpaceSize
(
M
,
N
,
StrideC
);
}
}
std
::
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
base_arg
)
const
override
{
const
auto
*
parg
=
dynamic_cast
<
const
Argument
*>
(
base_arg
);
if
(
!
parg
)
{
std
::
ostringstream
err
;
err
<<
"Provided argument pointer is not of an Argument class!"
<<
" In "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
return
GetWorkspaceSize
(
parg
->
M
,
parg
->
N
,
parg
->
K
,
parg
->
StrideA
,
parg
->
StrideB
,
parg
->
StrideC
);
}
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
View file @
d20c20a6
...
@@ -64,7 +64,7 @@ __global__ void
...
@@ -64,7 +64,7 @@ __global__ void
const
index_t
N
=
gemm_desc_ptr
[
group_id
].
N
;
const
index_t
N
=
gemm_desc_ptr
[
group_id
].
N
;
const
index_t
K
=
gemm_desc_ptr
[
group_id
].
K
;
const
index_t
K
=
gemm_desc_ptr
[
group_id
].
K
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
return
;
return
;
const
auto
StrideAs
=
gemm_desc_ptr
[
group_id
].
StrideAs
;
const
auto
StrideAs
=
gemm_desc_ptr
[
group_id
].
StrideAs
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
View file @
d20c20a6
...
@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const
index_t
N
=
gemm_descs
[
i
].
N_
;
const
index_t
N
=
gemm_descs
[
i
].
N_
;
const
index_t
K
=
gemm_descs
[
i
].
K_
;
const
index_t
K
=
gemm_descs
[
i
].
K_
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
{
{
skipped_group_count_
++
;
skipped_group_count_
++
;
continue
;
continue
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
View file @
d20c20a6
...
@@ -109,7 +109,7 @@ __global__ void
...
@@ -109,7 +109,7 @@ __global__ void
N
=
gemm_desc_ptr
[
group_id
].
N
;
N
=
gemm_desc_ptr
[
group_id
].
N
;
K
=
gemm_desc_ptr
[
group_id
].
K
;
K
=
gemm_desc_ptr
[
group_id
].
K
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
{
{
grid_size_grp
=
0
;
grid_size_grp
=
0
;
continue
;
continue
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
View file @
d20c20a6
...
@@ -68,7 +68,7 @@ __global__ void
...
@@ -68,7 +68,7 @@ __global__ void
const
index_t
N
=
gemm_desc_ptr
[
group_id
].
N
;
const
index_t
N
=
gemm_desc_ptr
[
group_id
].
N
;
const
index_t
K
=
gemm_desc_ptr
[
group_id
].
K
;
const
index_t
K
=
gemm_desc_ptr
[
group_id
].
K
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
return
;
return
;
const
auto
StrideA
=
gemm_desc_ptr
[
group_id
].
StrideA
;
const
auto
StrideA
=
gemm_desc_ptr
[
group_id
].
StrideA
;
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment