Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6fcaeada
"...composable_kernel_rocm.git" did not exist on "5aa9273b4f51897372d6c40465692de6de132243"
Commit
6fcaeada
authored
Oct 15, 2024
by
Astha Rai
Browse files
fixed merge conflict after merge with develop
parents
fc7a1825
d02a92cc
Changes
122
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
467 additions
and
80 deletions
+467
-80
example/ck_tile/01_fmha/fmha_bwd.hpp
example/ck_tile/01_fmha/fmha_bwd.hpp
+5
-1
example/ck_tile/01_fmha/fmha_fwd.cpp
example/ck_tile/01_fmha/fmha_fwd.cpp
+52
-18
example/ck_tile/01_fmha/fmha_fwd.hpp
example/ck_tile/01_fmha/fmha_fwd.hpp
+7
-8
example/ck_tile/02_layernorm2d/README.md
example/ck_tile/02_layernorm2d/README.md
+2
-1
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
+3
-1
example/ck_tile/03_gemm/README.md
example/ck_tile/03_gemm/README.md
+14
-6
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+44
-16
example/ck_tile/04_img2col/CMakeLists.txt
example/ck_tile/04_img2col/CMakeLists.txt
+3
-0
example/ck_tile/04_img2col/README.md
example/ck_tile/04_img2col/README.md
+13
-0
example/ck_tile/04_img2col/image_to_column.cpp
example/ck_tile/04_img2col/image_to_column.cpp
+170
-0
example/ck_tile/04_img2col/image_to_column.hpp
example/ck_tile/04_img2col/image_to_column.hpp
+105
-0
example/ck_tile/CMakeLists.txt
example/ck_tile/CMakeLists.txt
+1
-0
include/ck/config.h.in
include/ck/config.h.in
+0
-7
include/ck/host_utility/kernel_launch.hpp
include/ck/host_utility/kernel_launch.hpp
+6
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
...or_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
+2
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
+5
-4
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
+10
-8
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
+5
-4
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
+3
-3
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
...ation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
+17
-1
No files found.
example/ck_tile/01_fmha/fmha_bwd.hpp
View file @
6fcaeada
...
@@ -9,7 +9,10 @@
...
@@ -9,7 +9,10 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include "mask.hpp"
#include "bias.hpp"
#include "bias.hpp"
#include <type_traits>
#include <type_traits>
#include <utility>
#include <variant>
template
<
typename
DataType
>
template
<
typename
DataType
>
struct
FmhaBwdTypeConfig
;
struct
FmhaBwdTypeConfig
;
...
@@ -135,7 +138,8 @@ struct fmha_bwd_args
...
@@ -135,7 +138,8 @@ struct fmha_bwd_args
ck_tile
::
index_t
mask_type
;
ck_tile
::
index_t
mask_type
;
float
p_drop
;
float
p_drop
;
float
p_undrop
;
float
p_undrop
;
std
::
tuple
<
uint64_t
,
uint64_t
>
drop_seed_offset
;
std
::
variant
<
std
::
pair
<
uint64_t
,
uint64_t
>
,
std
::
pair
<
const
void
*
,
const
void
*>>
drop_seed_offset
;
};
};
template
<
typename
FmhaBwdDQDKDVKernel
>
template
<
typename
FmhaBwdDQDKDVKernel
>
...
...
example/ck_tile/01_fmha/fmha_fwd.cpp
View file @
6fcaeada
...
@@ -122,6 +122,9 @@ auto create_args(int argc, char* argv[])
...
@@ -122,6 +122,9 @@ auto create_args(int argc, char* argv[])
.
insert
(
"p_drop"
,
"0"
,
"0~1 probability of dropout"
)
.
insert
(
"p_drop"
,
"0"
,
"0~1 probability of dropout"
)
.
insert
(
"drop_seed"
,
"1"
,
"seed for random number generator"
)
.
insert
(
"drop_seed"
,
"1"
,
"seed for random number generator"
)
.
insert
(
"drop_offset"
,
"0"
,
"offset for random number generator"
)
.
insert
(
"drop_offset"
,
"0"
,
"offset for random number generator"
)
.
insert
(
"drop_prefs"
,
"0"
,
"seed and offset values are present on GPU; 0 - host, 1 - device/GPU"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
.
insert
(
.
insert
(
"rotary_dim"
,
"0"
,
"RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all"
)
"rotary_dim"
,
"0"
,
"RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all"
)
...
@@ -442,6 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -442,6 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
float
p_drop
=
arg_parser
.
get_float
(
"p_drop"
);
float
p_drop
=
arg_parser
.
get_float
(
"p_drop"
);
uint64_t
drop_seed
=
arg_parser
.
get_uint64
(
"drop_seed"
);
uint64_t
drop_seed
=
arg_parser
.
get_uint64
(
"drop_seed"
);
uint64_t
drop_offset
=
arg_parser
.
get_uint64
(
"drop_offset"
);
uint64_t
drop_offset
=
arg_parser
.
get_uint64
(
"drop_offset"
);
bool
drop_prefs
=
arg_parser
.
get_bool
(
"drop_prefs"
);
if
(
p_drop
<
0.0
f
||
p_drop
>
1.0
f
)
if
(
p_drop
<
0.0
f
||
p_drop
>
1.0
f
)
{
{
std
::
cerr
<<
"The value of p_drop should be 0~1"
<<
std
::
endl
;
std
::
cerr
<<
"The value of p_drop should be 0~1"
<<
std
::
endl
;
...
@@ -552,16 +557,33 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -552,16 +557,33 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
#endif
#endif
auto
get_lengths
=
[
&
](
bool
permute
,
struct
ck_tile
::
index_t
b
/*batch*/
,
{
ck_tile
::
index_t
h
/*nhead*/
,
auto
operator
()(
bool
permute
,
ck_tile
::
index_t
s
/*seqlen*/
,
ck_tile
::
index_t
b
/*batch*/
,
ck_tile
::
index_t
d
/*hdim*/
)
{
ck_tile
::
index_t
h
/*nhead*/
,
if
(
permute
)
ck_tile
::
index_t
s
/*seqlen*/
,
return
std
::
array
<
ck_tile
::
index_t
,
4
>
{
b
,
h
,
s
,
d
};
ck_tile
::
index_t
d
/*hdim*/
)
else
{
return
std
::
array
<
ck_tile
::
index_t
,
4
>
{
b
,
s
,
h
,
d
};
if
(
permute
)
};
return
std
::
array
<
ck_tile
::
index_t
,
4
>
{
b
,
h
,
s
,
d
};
else
return
std
::
array
<
ck_tile
::
index_t
,
4
>
{
b
,
s
,
h
,
d
};
}
auto
operator
()(
bool
permute
,
ck_tile
::
index_t
ns
/*num_splits*/
,
ck_tile
::
index_t
b
/*batch*/
,
ck_tile
::
index_t
h
/*nhead*/
,
ck_tile
::
index_t
s
/*seqlen*/
,
ck_tile
::
index_t
d
/*hdim*/
)
{
if
(
permute
)
return
std
::
array
<
ck_tile
::
index_t
,
5
>
{
ns
,
b
,
h
,
s
,
d
};
else
return
std
::
array
<
ck_tile
::
index_t
,
5
>
{
ns
,
b
,
s
,
h
,
d
};
}
}
get_lengths
;
bool
is_v_rowmajor
=
vlayout
==
std
::
string
(
"r"
);
bool
is_v_rowmajor
=
vlayout
==
std
::
string
(
"r"
);
...
@@ -617,7 +639,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -617,7 +639,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
});
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
});
ck_tile
::
HostTensor
<
OaccDataType
>
o_acc_host
(
ck_tile
::
HostTensor
<
OaccDataType
>
o_acc_host
(
1
<
num_splits
||
use_kvcache
1
<
num_splits
||
use_kvcache
?
std
::
array
<
ck_tile
::
index_t
,
5
>
{
num_splits
,
batch
,
nhead
,
max
_seqlen_q
,
hdim_v
}
?
get_lengths
(
o_perm
,
num_splits
,
shape_
batch
,
nhead
,
shape
_seqlen_q
,
hdim_v
)
:
std
::
array
<
ck_tile
::
index_t
,
5
>
{
1
,
1
,
1
,
1
,
1
});
:
std
::
array
<
ck_tile
::
index_t
,
5
>
{
1
,
1
,
1
,
1
,
1
});
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// batch mode of lse data layout is [batch, nhead, seqlen_q]
...
@@ -739,6 +761,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -739,6 +761,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
need_append_kvcache
?
cache_seqlen_ks
.
size
()
*
sizeof
(
int32_t
)
:
0
);
need_append_kvcache
?
cache_seqlen_ks
.
size
()
*
sizeof
(
int32_t
)
:
0
);
ck_tile
::
DeviceMem
rotary_cos_buf
(
rotary_cos_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
rotary_cos_buf
(
rotary_cos_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
rotary_sin_buf
(
rotary_sin_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
rotary_sin_buf
(
rotary_sin_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
drop_seed_buf
(
drop_prefs
?
sizeof
(
uint64_t
)
:
0
);
ck_tile
::
DeviceMem
drop_offset_buf
(
drop_prefs
?
sizeof
(
uint64_t
)
:
0
);
ck_tile
::
DeviceMem
randval_buf
(
randval_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
randval_buf
(
randval_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
alibi_slope_buf
(
alibi_slope_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
alibi_slope_buf
(
alibi_slope_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
block_table_buf
(
block_table_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
block_table_buf
(
block_table_host
.
get_element_space_size_in_bytes
());
...
@@ -757,6 +781,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -757,6 +781,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
cache_seqlen_k_buf
.
ToDevice
(
need_append_kvcache
?
cache_seqlen_ks
.
data
()
:
nullptr
);
cache_seqlen_k_buf
.
ToDevice
(
need_append_kvcache
?
cache_seqlen_ks
.
data
()
:
nullptr
);
rotary_cos_buf
.
ToDevice
(
rotary_cos_host
.
data
());
rotary_cos_buf
.
ToDevice
(
rotary_cos_host
.
data
());
rotary_sin_buf
.
ToDevice
(
rotary_sin_host
.
data
());
rotary_sin_buf
.
ToDevice
(
rotary_sin_host
.
data
());
drop_seed_buf
.
ToDevice
(
drop_prefs
?
&
drop_seed
:
nullptr
);
drop_offset_buf
.
ToDevice
(
drop_prefs
?
&
drop_offset
:
nullptr
);
alibi_slope_buf
.
ToDevice
(
alibi_slope_host
.
data
());
alibi_slope_buf
.
ToDevice
(
alibi_slope_host
.
data
());
block_table_buf
.
ToDevice
(
block_table_host
.
data
());
block_table_buf
.
ToDevice
(
block_table_host
.
data
());
cache_batch_idx_buf
.
ToDevice
(
cache_batch_idx_host
.
data
());
cache_batch_idx_buf
.
ToDevice
(
cache_batch_idx_host
.
data
());
...
@@ -854,7 +880,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -854,7 +880,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}();
}();
const
ck_tile
::
index_t
stride_bias
=
(
i_perm
?
shape_seqlen_k
:
1
*
shape_seqlen_k
);
const
ck_tile
::
index_t
stride_bias
=
(
i_perm
?
shape_seqlen_k
:
1
*
shape_seqlen_k
);
const
ck_tile
::
index_t
stride_randval
=
(
max_seqlen_k
);
const
ck_tile
::
index_t
stride_randval
=
(
max_seqlen_k
);
const
ck_tile
::
index_t
stride_o_acc
=
hdim_v
;
const
ck_tile
::
index_t
stride_o_acc
=
(
o_perm
?
hdim_v
:
nhead
*
hdim_v
)
;
const
ck_tile
::
index_t
stride_o
=
(
o_perm
?
hdim_v
:
nhead
*
hdim_v
);
const
ck_tile
::
index_t
stride_o
=
(
o_perm
?
hdim_v
:
nhead
*
hdim_v
);
// setup nhead_stride_* arguments
// setup nhead_stride_* arguments
const
ck_tile
::
index_t
nhead_stride_q
=
(
i_perm
?
shape_seqlen_q
*
hdim_q
:
hdim_q
);
const
ck_tile
::
index_t
nhead_stride_q
=
(
i_perm
?
shape_seqlen_q
*
hdim_q
:
hdim_q
);
...
@@ -881,7 +907,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -881,7 +907,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
nhead_stride_randval
=
(
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
nhead_stride_randval
=
(
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
nhead_stride_lse
=
shape_seqlen_q
;
const
ck_tile
::
index_t
nhead_stride_lse
=
shape_seqlen_q
;
const
ck_tile
::
index_t
nhead_stride_lse_acc
=
shape_seqlen_q
;
const
ck_tile
::
index_t
nhead_stride_lse_acc
=
shape_seqlen_q
;
const
ck_tile
::
index_t
nhead_stride_o_acc
=
(
max
_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_o_acc
=
(
o_perm
?
shape
_seqlen_q
*
hdim_v
:
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_o
=
(
o_perm
?
shape_seqlen_q
*
hdim_v
:
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_o
=
(
o_perm
?
shape_seqlen_q
*
hdim_v
:
hdim_v
);
// setup batch_stride_* arguments
// setup batch_stride_* arguments
const
ck_tile
::
index_t
batch_stride_q
=
(
nhead
*
shape_seqlen_q
*
hdim_q
);
const
ck_tile
::
index_t
batch_stride_q
=
(
nhead
*
shape_seqlen_q
*
hdim_q
);
...
@@ -897,12 +923,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -897,12 +923,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
batch_stride_randval
=
(
nhead
*
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
batch_stride_randval
=
(
nhead
*
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
batch_stride_lse
=
(
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_lse
=
(
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_lse_acc
=
(
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_lse_acc
=
(
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_o_acc
=
(
nhead
*
max
_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_o_acc
=
(
nhead
*
shape
_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_o
=
(
nhead
*
shape_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_o
=
(
nhead
*
shape_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_block_table
=
(
max_num_page_blocks
/
batch
);
const
ck_tile
::
index_t
batch_stride_block_table
=
(
max_num_page_blocks
/
batch
);
// setup split_stride_* arguments (only used in split-kv kernel)
// setup split_stride_* arguments (only used in split-kv kernel)
const
ck_tile
::
index_t
split_stride_lse_acc
=
(
shape_batch
*
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
split_stride_lse_acc
=
(
shape_batch
*
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
split_stride_o_acc
=
(
batch
*
nhead
*
max
_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
split_stride_o_acc
=
(
shape_
batch
*
nhead
*
shape
_seqlen_q
*
hdim_v
);
args
.
q_ptr
=
q_buf
.
GetDeviceBuffer
();
args
.
q_ptr
=
q_buf
.
GetDeviceBuffer
();
args
.
k_ptr
=
k_buf
.
GetDeviceBuffer
();
args
.
k_ptr
=
k_buf
.
GetDeviceBuffer
();
...
@@ -996,9 +1022,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -996,9 +1022,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
args
.
nhead_stride_randval
=
nhead_stride_randval
;
args
.
nhead_stride_randval
=
nhead_stride_randval
;
args
.
batch_stride_randval
=
batch_stride_randval
;
args
.
batch_stride_randval
=
batch_stride_randval
;
args
.
p_drop
=
p_drop
;
args
.
p_drop
=
p_drop
;
args
.
s_randval
=
s_randval
;
args
.
s_randval
=
s_randval
;
args
.
drop_seed_offset
=
std
::
tie
(
drop_seed
,
drop_offset
);
if
(
drop_prefs
)
{
args
.
drop_seed_offset
=
std
::
make_pair
(
drop_seed_buf
.
GetDeviceBuffer
(),
drop_offset_buf
.
GetDeviceBuffer
());
}
else
{
args
.
drop_seed_offset
=
std
::
make_pair
(
drop_seed
,
drop_offset
);
}
}
}
else
if
constexpr
(
std
::
is_same_v
<
fmha_fwd_splitkv_args
,
std
::
decay_t
<
decltype
(
args
)
>>
)
else
if
constexpr
(
std
::
is_same_v
<
fmha_fwd_splitkv_args
,
std
::
decay_t
<
decltype
(
args
)
>>
)
{
{
...
...
example/ck_tile/01_fmha/fmha_fwd.hpp
View file @
6fcaeada
...
@@ -13,6 +13,8 @@
...
@@ -13,6 +13,8 @@
#include "rotary.hpp"
#include "rotary.hpp"
#include <type_traits>
#include <type_traits>
#include <utility>
#include <variant>
template
<
typename
DataType
>
template
<
typename
DataType
>
struct
FmhaFwdTypeConfig
;
struct
FmhaFwdTypeConfig
;
...
@@ -144,7 +146,9 @@ struct fmha_fwd_args
...
@@ -144,7 +146,9 @@ struct fmha_fwd_args
float
p_drop
;
float
p_drop
;
bool
s_randval
;
bool
s_randval
;
std
::
tuple
<
uint64_t
,
uint64_t
>
drop_seed_offset
;
std
::
variant
<
std
::
pair
<
uint64_t
,
uint64_t
>
,
std
::
pair
<
const
void
*
,
const
void
*>>
drop_seed_offset
;
};
};
struct
fmha_fwd_splitkv_args
struct
fmha_fwd_splitkv_args
...
@@ -398,10 +402,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
...
@@ -398,10 +402,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args
.
nhead_stride_bias
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_lse_acc
,
args
.
nhead_stride_lse_acc
,
args
.
nhead_stride_o_acc
,
args
.
nhead_stride_o_acc
,
args
.
batch_stride_k
,
args
.
batch_stride_k
,
// only used for paged-kvcache
args
.
batch_stride_v
,
args
.
batch_stride_v
,
// only used for paged-kvcache
args
.
batch_stride_lse_acc
,
args
.
batch_stride_o_acc
,
args
.
split_stride_lse_acc
,
args
.
split_stride_lse_acc
,
args
.
split_stride_o_acc
,
args
.
split_stride_o_acc
,
args
.
window_size_left
,
args
.
window_size_left
,
...
@@ -475,7 +477,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
...
@@ -475,7 +477,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args
.
lse_ptr
,
args
.
lse_ptr
,
args
.
o_ptr
,
args
.
o_ptr
,
args
.
batch
,
args
.
batch
,
args
.
max_seqlen_q
,
args
.
seqstart_q_ptr
,
args
.
seqstart_q_ptr
,
args
.
hdim_v
,
args
.
hdim_v
,
args
.
num_splits
,
args
.
num_splits
,
...
@@ -486,7 +487,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
...
@@ -486,7 +487,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args
.
nhead_stride_o_acc
,
args
.
nhead_stride_o_acc
,
args
.
nhead_stride_lse
,
args
.
nhead_stride_lse
,
args
.
nhead_stride_o
,
args
.
nhead_stride_o
,
args
.
batch_stride_o_acc
,
args
.
split_stride_lse_acc
,
args
.
split_stride_lse_acc
,
args
.
split_stride_o_acc
);
args
.
split_stride_o_acc
);
}
}
...
@@ -497,7 +497,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
...
@@ -497,7 +497,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args
.
lse_ptr
,
args
.
lse_ptr
,
args
.
o_ptr
,
args
.
o_ptr
,
args
.
batch
,
args
.
batch
,
args
.
max_seqlen_q
,
args
.
seqlen_q
,
args
.
seqlen_q
,
args
.
hdim_v
,
args
.
hdim_v
,
args
.
num_splits
,
args
.
num_splits
,
...
...
example/ck_tile/02_layernorm2d/README.md
View file @
6fcaeada
...
@@ -6,7 +6,8 @@ This folder contains example for Layernorm2D forward using ck_tile tile-programm
...
@@ -6,7 +6,8 @@ This folder contains example for Layernorm2D forward using ck_tile tile-programm
```
```
# in the root of ck_tile
# in the root of ck_tile
mkdir build && cd build
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_layernorm2d_fwd -j
make tile_example_layernorm2d_fwd -j
```
```
This will result in an executable
`build/bin/tile_example_layernorm2d_fwd`
This will result in an executable
`build/bin/tile_example_layernorm2d_fwd`
...
...
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
View file @
6fcaeada
...
@@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
...
@@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
YDataType
,
YDataType
,
MeanDataType
,
MeanDataType
,
InvStdDataType
,
InvStdDataType
,
Shape
>
;
Shape
,
true
,
true
>
;
using
Kernel
=
ck_tile
::
Layernorm2dFwd
<
PipelineProblem
>
;
using
Kernel
=
ck_tile
::
Layernorm2dFwd
<
PipelineProblem
>
;
...
...
example/ck_tile/03_gemm/README.md
View file @
6fcaeada
...
@@ -6,7 +6,8 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
...
@@ -6,7 +6,8 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
```
```
# in the root of ck_tile
# in the root of ck_tile
mkdir build && cd build
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_gemm_basic -j
make tile_example_gemm_basic -j
```
```
This will result in an executable
`build/bin/tile_example_gemm_basic`
This will result in an executable
`build/bin/tile_example_gemm_basic`
...
@@ -14,10 +15,17 @@ This will result in an executable `build/bin/tile_example_gemm_basic`
...
@@ -14,10 +15,17 @@ This will result in an executable `build/bin/tile_example_gemm_basic`
## example
## example
```
```
args:
args:
-m m dimension (default:3328)
-b batch size (default:1)
-n m dimension (default:4096)
-m m dimension (default:1024)
-n n dimension (default:2048)
-k k dimension (default:64)
-k k dimension (default:64)
-e epsilon (default:1e-5)
-stride_a Tensor A stride (default:0)
-v cpu validation or not (default:1)
-stride_b Tensor B stride (default:0)
-prec precision (default:fp16)
-stride_c Tensor C stride (default:0)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
```
```
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
6fcaeada
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -41,18 +40,39 @@ template <typename LayoutA,
...
@@ -41,18 +40,39 @@ template <typename LayoutA,
float
gemm_calc
(
const
gemm_basic_args
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
gemm_calc
(
const
gemm_basic_args
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kTilePermute
=
false
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
>>
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
LayoutC
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
kM
,
TilePartitioner
::
kN
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
>>>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
,
LayoutA
,
LayoutB
,
LayoutC
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
args
.
p_b
,
args
.
p_b
,
...
@@ -255,15 +275,17 @@ int main(int argc, char* argv[])
...
@@ -255,15 +275,17 @@ int main(int argc, char* argv[])
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
CodegenPipelineProblem
=
ck_tile
::
BlockGemmPipelineProblem
<
ADataType
,
using
CodegenGemmTraits
=
ck_tile
::
BDataType
,
TileGemmTraits
<
kPadA
,
kPadB
,
kPadC
,
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
;
AccDataType
,
CodegenGemmShape
,
using
CodegenPipelineProblem
=
ck_tile
::
kPadA
,
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
kPadB
,
kPadC
>
;
using
CodegenGemmPolicy
=
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
<
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
;
using
CodegenGemmPipeline
=
ck_tile
::
BlockGemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
,
CodegenGemmPolicy
>
;
invoke_gemm
<
ck_tile
::
half_t
,
invoke_gemm
<
ck_tile
::
half_t
,
matrix_a_layout
,
matrix_a_layout
,
...
@@ -341,7 +363,13 @@ int main(int argc, char* argv[])
...
@@ -341,7 +363,13 @@ int main(int argc, char* argv[])
ck_tile
::
HostTensor
<
CDataType
>
c_host_gpu_ref
(
c_dimensions
);
ck_tile
::
HostTensor
<
CDataType
>
c_host_gpu_ref
(
c_dimensions
);
ck_tile
::
DeviceMem
c_gpu_buf
(
c_host_gpu_ref
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
c_gpu_buf
(
c_host_gpu_ref
.
get_element_space_size_in_bytes
());
ck_tile
::
reference_gemm_gpu
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
ck_tile
::
reference_gemm_gpu
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
(
a_buf
,
b_buf
,
c_gpu_buf
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
a_buf
,
b_buf
,
c_gpu_buf
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
c_buf
.
FromDevice
(
c_host_gpu_ref
.
data
());
c_buf
.
FromDevice
(
c_host_gpu_ref
.
data
());
...
...
example/ck_tile/04_img2col/CMakeLists.txt
0 → 100644
View file @
6fcaeada
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
add_executable
(
tile_example_img2col EXCLUDE_FROM_ALL image_to_column.cpp
)
example/ck_tile/04_img2col/README.md
0 → 100644
View file @
6fcaeada
# Image to Column
This folder contains example for Image to Column using ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_img2col -j
```
This will result in an executable
`build/bin/tile_example_img2col`
example/ck_tile/04_img2col/image_to_column.cpp
0 → 100644
View file @
6fcaeada
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstring>
#include "ck_tile/host.hpp"
#include "image_to_column.hpp"
// Host API implementation
template
<
>
float
image_to_column
(
const
image_to_column_traits
&
traits
,
const
image_to_column_args
<
2
>&
args
,
const
ck_tile
::
stream_config
&
stream_conf
)
{
if
(
traits
.
data_type
.
compare
(
"fp16"
)
==
0
)
{
constexpr
ck_tile
::
index_t
NDimSpatial
=
2
;
constexpr
ck_tile
::
index_t
VectorSize
=
8
;
using
thread_tile
=
ck_tile
::
sequence
<
8
,
8
>
;
using
warp_tile
=
ck_tile
::
sequence
<
64
,
64
>
;
using
block_tile
=
ck_tile
::
sequence
<
128
,
128
>
;
using
Shape
=
ck_tile
::
TileImageToColumnShape
<
thread_tile
,
warp_tile
,
block_tile
>
;
using
InDataType
=
ck_tile
::
half_t
;
using
OutDataType
=
ck_tile
::
half_t
;
using
PipelineProblem
=
ck_tile
::
BlockImageToColumnProblem
<
InDataType
,
OutDataType
,
Shape
,
NDimSpatial
,
VectorSize
,
VectorSize
>
;
using
Kernel
=
ck_tile
::
ImageToColumn
<
PipelineProblem
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_in
,
args
.
p_out
,
args
.
G
,
args
.
N
,
args
.
C
,
args
.
input_spatial_lengths
,
args
.
filter_spatial_lengths
,
args
.
output_spatial_lengths
,
args
.
image_g_n_c_wis_strides
,
args
.
gemm_g_m_k_strides
,
args
.
conv_filter_strides
,
args
.
conv_filter_dilations
,
args
.
input_left_pads
,
args
.
input_right_pads
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
N
*
args
.
output_spatial_lengths
[
0
]
*
args
.
output_spatial_lengths
[
1
],
args
.
filter_spatial_lengths
[
0
]
*
args
.
filter_spatial_lengths
[
1
]
*
args
.
C
,
args
.
G
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
ck_tile
::
index_t
kBlockPerCu
=
2
;
float
ave_time
=
ck_tile
::
launch_kernel
(
stream_conf
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
return
ave_time
;
}
return
0
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
constexpr
ck_tile
::
index_t
NDimSpatial
=
2
;
ExecutionConfig
config
;
ck_tile
::
conv
::
ConvParam
conv_params
=
DefaultConvParams
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_params
))
{
return
EXIT_FAILURE
;
}
if
(
conv_params
.
num_dim_spatial_
!=
NDimSpatial
)
{
std
::
cerr
<<
"unsupported # of spatial dimensions"
<<
std
::
endl
;
return
EXIT_FAILURE
;
}
using
InDataType
=
ck_tile
::
half_t
;
using
OutDataType
=
ck_tile
::
half_t
;
using
ImLayout
=
ck_tile
::
tensor_layout
::
convolution
::
NHWGC
;
const
auto
G
=
conv_params
.
G_
;
const
auto
N
=
conv_params
.
N_
;
const
auto
C
=
conv_params
.
C_
;
const
ck_tile
::
long_index_t
NHoWo
=
N
*
std
::
accumulate
(
conv_params
.
output_spatial_lengths_
.
begin
(),
std
::
next
(
conv_params
.
output_spatial_lengths_
.
begin
(),
NDimSpatial
),
1
,
std
::
multiplies
<>
());
const
ck_tile
::
long_index_t
CYX
=
C
*
std
::
accumulate
(
conv_params
.
filter_spatial_lengths_
.
begin
(),
std
::
next
(
conv_params
.
filter_spatial_lengths_
.
begin
(),
NDimSpatial
),
1
,
std
::
multiplies
<>
());
const
auto
in_desc
=
ck_tile
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
ImLayout
>
(
conv_params
);
const
auto
out_desc
=
ck_tile
::
HostTensorDescriptor
({
G
,
NHoWo
,
CYX
});
// host verify
ck_tile
::
HostTensor
<
InDataType
>
in
(
in_desc
);
ck_tile
::
HostTensor
<
OutDataType
>
out_device
(
out_desc
);
ck_tile
::
HostTensor
<
OutDataType
>
out_host
(
out_desc
);
switch
(
config
.
init_method
)
{
case
0
:
break
;
case
1
:
ck_tile
::
FillUniformDistributionIntegerValue
<
InDataType
>
{
-
5.
f
,
5.
f
}(
in
);
break
;
default:
ck_tile
::
FillUniformDistribution
<
InDataType
>
{
-
0.5
,
0.5
}(
in
);
break
;
}
ck_tile
::
DeviceMem
in_device_buf
(
in
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
out_device_buf
(
out_device
.
get_element_space_size_in_bytes
());
in_device_buf
.
ToDevice
(
in
.
data
());
image_to_column_traits
traits
{
"fp16"
};
image_to_column_args
<
NDimSpatial
>
args
{
in_device_buf
.
GetDeviceBuffer
(),
out_device_buf
.
GetDeviceBuffer
(),
G
,
N
,
C
,
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
input_spatial_lengths_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
filter_spatial_lengths_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
output_spatial_lengths_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
+
3
>
(
in_desc
.
get_strides
()),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
3
>
(
out_desc
.
get_strides
()),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
conv_filter_strides_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
conv_filter_dilations_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
input_left_pads_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
input_right_pads_
)};
float
ave_time
=
image_to_column
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
config
.
time_kernel
});
std
::
size_t
num_btype
=
G
*
NHoWo
*
CYX
*
(
sizeof
(
OutDataType
)
+
sizeof
(
InDataType
));
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
bool
pass
=
true
;
if
(
config
.
do_verification
)
{
// reference
ck_tile
::
reference_im2col
<
InDataType
,
OutDataType
,
NDimSpatial
>
(
in
,
out_host
,
conv_params
);
out_device_buf
.
FromDevice
(
out_device
.
data
());
pass
=
ck_tile
::
check_err
(
out_device
,
out_host
);
std
::
cout
<<
"valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
endl
;
}
return
!
pass
;
}
example/ck_tile/04_img2col/image_to_column.hpp
0 → 100644
View file @
6fcaeada
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/image_to_column.hpp"
#include <string>
// Default convolution problem used when no ConvParam arguments are given.
// Leading values follow the parser's leading-args order (num_dim_spatial_,
// G_, N_, K_, C_) = (2, 2, 32, 32, 32); the brace lists are presumably the
// filter/image spatial lengths, strides, dilations and left/right pads —
// TODO confirm against ck_tile::conv::ConvParam's constructor order.
#define DefaultConvParams                                                    \
    ck_tile::conv::ConvParam                                                 \
    {                                                                        \
        2, 2, 32, 32, 32, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \
    }
struct
ExecutionConfig
final
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
};
// Write the command-line usage text to stderr, including the conv-param
// helper message provided by ck_tile.
inline void print_help_msg()
{
    std::cerr << "arg1: verification (0=no, 1=yes)\n";
    std::cerr << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n";
    std::cerr << "arg3: time kernel (0=no, 1=yes)\n";
    std::cerr << ck_tile::conv::get_conv_param_parser_helper_msg() << std::endl;
}
inline
bool
parse_cmd_args
(
int
argc
,
char
*
argv
[],
ExecutionConfig
&
config
,
ck_tile
::
conv
::
ConvParam
&
conv_params
)
{
constexpr
int
num_execution_config_args
=
3
;
// arguments for do_verification, init_method, time_kernel
constexpr
int
num_conv_param_leading_args
=
5
;
// arguments for num_dim_spatial_, G_, N_, K_, C_
constexpr
int
threshold_to_catch_partial_args
=
1
+
num_execution_config_args
;
constexpr
int
threshold_to_catch_all_args
=
threshold_to_catch_partial_args
+
num_conv_param_leading_args
;
if
(
argc
==
1
)
{
// use default
config
=
ExecutionConfig
{};
}
// catch only ExecutionConfig arguments
else
if
(
argc
==
threshold_to_catch_partial_args
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
// catch both ExecutionConfig & ConvParam arguments
else
if
(
threshold_to_catch_all_args
<
argc
&&
((
argc
-
threshold_to_catch_all_args
)
%
3
==
0
))
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
const
ck_tile
::
index_t
num_dim_spatial
=
std
::
stoi
(
argv
[
4
]);
conv_params
=
ck_tile
::
conv
::
parse_conv_param
(
num_dim_spatial
,
threshold_to_catch_partial_args
,
argv
);
}
else
{
print_help_msg
();
return
false
;
}
return
true
;
}
// Runtime traits for the image_to_column host API; selects the element data
// type by name (the example passes "fp16").
struct image_to_column_traits
{
    std::string data_type;
};
// Argument pack for the image-to-column (im2col) host API. All members are
// set via aggregate initialization (see the example's call site); the const
// members make instances single-use by design.
template <ck_tile::index_t NDimSpatial>
struct image_to_column_args
{
    const void* p_in; // device pointer to the input image
    void* p_out;      // device pointer to the column (GEMM-layout) output

    const ck_tile::long_index_t G; // number of groups
    const ck_tile::long_index_t N; // batch size
    const ck_tile::long_index_t C; // channels (per group — TODO confirm)

    const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_spatial_lengths;
    const ck_tile::array<ck_tile::long_index_t, NDimSpatial> filter_spatial_lengths;
    const ck_tile::array<ck_tile::long_index_t, NDimSpatial> output_spatial_lengths;
    // image strides for the [G, N, C, spatial...] layout (hence NDimSpatial + 3)
    const ck_tile::array<ck_tile::long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides;
    // output strides for the [G, M, K] GEMM view
    const ck_tile::array<ck_tile::long_index_t, 3> gemm_g_m_k_strides;
    const ck_tile::array<ck_tile::long_index_t, NDimSpatial> conv_filter_strides;
    const ck_tile::array<ck_tile::long_index_t, NDimSpatial> conv_filter_dilations;
    const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_left_pads;
    const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_right_pads;
};
// host API
//
// Launch the image-to-column kernel described by the traits/args and return
// the kernel time (the example reports it in milliseconds; timing is
// controlled by the stream_config). Definition lives in the accompanying
// .cpp translation unit.
template <ck_tile::index_t NDimSpatial>
float image_to_column(const image_to_column_traits&,
                      const image_to_column_args<NDimSpatial>&,
                      const ck_tile::stream_config&);
example/ck_tile/CMakeLists.txt
View file @
6fcaeada
...
@@ -5,3 +5,4 @@ include_directories(AFTER
...
@@ -5,3 +5,4 @@ include_directories(AFTER
add_subdirectory
(
01_fmha
)
add_subdirectory
(
01_fmha
)
add_subdirectory
(
02_layernorm2d
)
add_subdirectory
(
02_layernorm2d
)
add_subdirectory
(
03_gemm
)
add_subdirectory
(
03_gemm
)
add_subdirectory
(
04_img2col
)
include/ck/config.h.in
View file @
6fcaeada
...
@@ -97,13 +97,6 @@
...
@@ -97,13 +97,6 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif
#endif
//
// Instances supports in the current CK build
//
#ifndef CK_ENABLE_INSTANCES_ONLY
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
//
//
// CK kernels which support XDL (MI series)
// CK kernels which support XDL (MI series)
//
//
...
...
include/ck/host_utility/kernel_launch.hpp
View file @
6fcaeada
...
@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
...
@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
hip_check_error
(
hipEventElapsedTime
(
&
total_time
,
start
,
stop
));
hip_check_error
(
hipEventElapsedTime
(
&
total_time
,
start
,
stop
));
hip_check_error
(
hipEventDestroy
(
start
));
hip_check_error
(
hipEventDestroy
(
stop
));
return
total_time
/
nrepeat
;
return
total_time
/
nrepeat
;
}
}
else
else
...
@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
...
@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
hip_check_error
(
hipEventElapsedTime
(
&
total_time
,
start
,
stop
));
hip_check_error
(
hipEventElapsedTime
(
&
total_time
,
start
,
stop
));
hip_check_error
(
hipEventDestroy
(
start
));
hip_check_error
(
hipEventDestroy
(
stop
));
return
total_time
/
nrepeat
;
return
total_time
/
nrepeat
;
}
}
else
else
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
View file @
6fcaeada
...
@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
...
@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
}
}
template
<
>
template
<
>
__device__
static
constexpr
auto
TailScheduler
<
1
>
()
__device__
constexpr
auto
TailScheduler
<
1
>
()
{
{
// schedule
// schedule
constexpr
auto
num_ds_read_inst
=
constexpr
auto
num_ds_read_inst
=
...
@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
...
@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
}
}
template
<
>
template
<
>
__device__
static
constexpr
auto
TailScheduler
<
2
>
()
__device__
constexpr
auto
TailScheduler
<
2
>
()
{
{
// schedule
// schedule
constexpr
auto
num_ds_read_inst
=
constexpr
auto
num_ds_read_inst
=
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
View file @
6fcaeada
...
@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
xdlops_gemm
.
template
Run
<
>(
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
View file @
6fcaeada
...
@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
xdlops_gemm
.
template
Run
<
>(
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
...
@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
xdlops_gemm
.
template
Run
<
>(
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
View file @
6fcaeada
...
@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
xdlops_gemm
.
template
Run
<
>(
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
...
...
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
View file @
6fcaeada
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "device_base.hpp"
#include "device_base.hpp"
...
@@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator
...
@@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
KBatch
=
1
)
=
0
;
ck
::
index_t
KBatch
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
size_t
GetWorkspaceSize
(
index_t
MRaw
,
virtual
std
::
size_t
GetWorkspaceSize
(
index_t
MRaw
,
index_t
NRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
)
=
0
;
index_t
StrideC
)
const
=
0
;
};
};
template
<
typename
AElementwiseOperation
,
template
<
typename
AElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
6fcaeada
...
@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
[[
maybe_unused
]]
index_t
K
,
[[
maybe_unused
]]
index_t
K
,
[[
maybe_unused
]]
index_t
StrideA
,
[[
maybe_unused
]]
index_t
StrideA
,
[[
maybe_unused
]]
index_t
StrideB
,
[[
maybe_unused
]]
index_t
StrideB
,
index_t
StrideC
)
override
index_t
StrideC
)
const
override
{
{
return
2
*
sizeof
(
CDataType
)
*
GetCElementSpaceSize
(
M
,
N
,
StrideC
);
return
2
*
sizeof
(
CDataType
)
*
GetCElementSpaceSize
(
M
,
N
,
StrideC
);
}
}
std
::
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
base_arg
)
const
override
{
const
auto
*
parg
=
dynamic_cast
<
const
Argument
*>
(
base_arg
);
if
(
!
parg
)
{
std
::
ostringstream
err
;
err
<<
"Provided argument pointer is not of an Argument class!"
<<
" In "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
return
GetWorkspaceSize
(
parg
->
M
,
parg
->
N
,
parg
->
K
,
parg
->
StrideA
,
parg
->
StrideB
,
parg
->
StrideC
);
}
};
};
}
// namespace device
}
// namespace device
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment