gaoqiong / composable_kernel_ROCM · Commit 401e643e
Authored Dec 17, 2024 by Po Yen Chen

Merge branch 'develop' into feature/use-larger-tile-size-for-chunk-prefill

Parents: d783a8cf, fdfe2102
Changes: 61 in total; showing 20 changed files with 1063 additions and 212 deletions (+1063 / -212)
include/ck_tile/ref/naive_attention.hpp                                                                                                                                              +666  -0
include/ck_tile/remod.py                                                                                                                                                             +4    -0
library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp      +3    -0
library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_f8_f8_bf16/device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp              +2    -0
library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt                                                                                                                         +3    -3
modified_files.txt                                                                                                                                                                   +0    -10
profiler/include/profiler/profile_gemm_universal_batched_impl.hpp                                                                                                                    +80   -68
profiler/src/profile_gemm_universal_batched.cpp                                                                                                                                      +11   -9
script/process_perf_data.py                                                                                                                                                          +3    -3
script/process_perf_data.sh                                                                                                                                                          +13   -0
script/process_qa_data.sh                                                                                                                                                            +12   -0
script/run_full_performance_tests.sh                                                                                                                                                 +1    -1
script/run_gemm_performance_tests.sh                                                                                                                                                 +41   -0
script/run_performance_tests.sh                                                                                                                                                      +6    -15
test/ck_tile/gemm/test_gemm_mem_pipeline.cpp                                                                                                                                         +18   -24
test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc                                                                                                                                +12   -47
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp                                                                                                                                    +15   -7
test/data_type/test_custom_type.cpp                                                                                                                                                  +59   -23
test/grouped_convnd_bwd_data/CMakeLists.txt                                                                                                                                          +6    -2
test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_wmma.cpp                                                                                                                   +108  -0
include/ck_tile/ref/naive_attention.hpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include <thread>
#include <string>

namespace ck_tile {

enum class naive_attention_layout_enum
{
    BSHD,  // [batch, seqlen, nhead, hdim]
    BHSD,  // [batch, nhead, seqlen, hdim]
    BS3HD, // [batch, nhead, 3, seqlen, hdim], used when qkv are packed
    PHSD,  // [pages, nhead, page_size, hdim]
    // PHSDX, // [pages, nhead, page_size/x, hdim, x], where <# used pages>*page_size = seqlen
    PHDSX, // [pages, nhead, hdim/x, page_size, x], where <# used pages>*page_size = seqlen
    PHDS,  // [pages, nhead, hdim, page_size], where <# used pages>*page_size = seqlen
};

// will be used to specialize kernel variations
enum class naive_attention_variation_enum
{
    FLASH_BATCHED = 0, // standard flash attention, or xformer/sdpa, used for training
    FLASH_GROUPED,
    DECODE_PAGED, // decode attn, where kv tokens come from another buffer called kvcache
};

// TODO: for simplicity, this will be used as host/device arg
struct naive_attention_fwd_args
{
    void* q_ptr;
    void* k_ptr;
    void* v_ptr;
    void* o_ptr;
    void* context_len_ptr; // [batch] used when seqlen kv comes from a pointer (each element is a
                           // number, not a cumsum)
    void* page_table_ptr;  // [batch, max_pages_per_seq] seqlen_kv is in different blocks (paged attn)
    void* kvscale_ptr;     // [nhead, 2(kv), hdim] used for kvcache dequant
    float scale_s;
    int hdim;
    int hdim_v; // could be cross-attn, where V and Q/K hdim are different
    int batch_q;
    int batch_kv;
    int batch_ratio_kv; // batch_q / batch_kv
    int seqlen_q;       // in decode case, this should be 1
    int seqlen_kv;      // if context_len_ptr is not nullptr, ignore this field
    int nhead_q;
    int nhead_kv;
    int nhead_ratio_kv; // nhead_q / nhead_kv
    int page_size;      // if paged, the seqlen-kv per each block
    int max_pages_per_seq;
};

// this is the trait for the host API
struct naive_attention_fwd_traits
{
    std::string q_type;
    std::string k_type;
    std::string v_type;
    std::string o_type;
    std::string q_layout;
    std::string k_layout;
    std::string v_layout;
    std::string o_layout;
    int variation; // sync with naive_attention_variation_enum
};

// this is the trait for the kernel template
template <naive_attention_variation_enum variation_>
struct naive_attention_fwd_kernel_traits
{
    static constexpr naive_attention_variation_enum variation = variation_;
};

// for simplicity, please do not use const-reference types for the template types
template <typename QType,
          typename KType,
          typename VType,
          typename OType,
          typename AccType,
          naive_attention_layout_enum QLayout,
          naive_attention_layout_enum KLayout,
          naive_attention_layout_enum VLayout,
          naive_attention_layout_enum OLayout,
          typename Traits>
struct naive_attention_fwd_kernel
{
    static constexpr bool is_kvcache_i8 =
        std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t> && sizeof(QType) != 1;
    // kvcache-i8 will have a per-head scale; we apply this scale to the Q/P matrix instead of the
    // original K/V matrix. This can speed up conversion since Q/P usually is fp16/bf16/fp32.
    static constexpr bool is_kvcache_i8_forward_quant = is_kvcache_i8;

    // TODO: hardcode
    using KVScaleType = float;
    using SoftmaxType = float;
    using PType       = VType; // src A of gemm2, same type as V
    using p_vec_type  = ext_vector_t<PType, 16 / sizeof(PType)>;
    static constexpr int p_vec_elem = vector_traits<p_vec_type>::vector_size;

    __host__ __device__ naive_attention_fwd_kernel() {}

    template <typename T, naive_attention_layout_enum Layout>
    struct addresser
    {
        int b, s, h, d; // batch, seqlen, nhead, hdim
        T* base_ptr;

        __device__ addresser(int b_, int s_, int h_, int d_, void* base_ptr_)
            : b(b_), s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(base_ptr_))
        {
        }

        // TODO: all the batch/nhead offsets will accumulate into the base pointer
        __device__ T* get_base(int i_b, int i_h)
        {
            if constexpr(Layout == naive_attention_layout_enum::BSHD)
                return base_ptr + i_b * s * h * d + i_h * d;
            else if constexpr(Layout == naive_attention_layout_enum::BHSD)
                return base_ptr + i_b * s * h * d + i_h * s * d;
        }

        __device__ int get_offset(int i_s, int i_d)
        {
            if constexpr(Layout == naive_attention_layout_enum::BSHD)
                return i_s * h * d + i_d;
            else if constexpr(Layout == naive_attention_layout_enum::BHSD)
                return i_s * d + i_d;
        }

        // the set of APIs below directly uses the pointer stored inside this struct
        __device__ void init(int i_b, int i_h) { base_ptr = get_base(i_b, i_h); }
        __device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
        __device__ void store(T value, int i_s, int i_d) { base_ptr[get_offset(i_s, i_d)] = value; }
    };

    template <typename T, naive_attention_layout_enum Layout>
    struct page_addresser
    {
        int s, h, d; // page_size, nhead, hdim
        static constexpr int x = 16 / sizeof(T); // pack 4 dwords
        T* base_ptr;
        int* page_table_ptr; // TODO: page table is always int
        int i_h;             // store current head

        __device__ page_addresser(int s_, int h_, int d_, void* base_ptr_, void* pptr_)
            : s(s_),
              h(h_),
              d(d_),
              base_ptr(reinterpret_cast<T*>(base_ptr_)),
              page_table_ptr(reinterpret_cast<int*>(pptr_))
        {
        }

        __device__ int64_t get_phy_page_idx(int i_s)
        {
            // dynamically computing the page idx is simple but slow
            int page_idx = i_s / s;
            int phy      = page_table_ptr[page_idx];
            return static_cast<int64_t>(phy);
        }

        __device__ int get_phy_page_offset(int i_s)
        {
            // dynamically computing the page offset is simple but slow
            return i_s % s;
        }

        __device__ int64_t get_offset(int i_s, int i_d)
        {
            int page_offset  = get_phy_page_offset(i_s);
            int64_t page_idx = get_phy_page_idx(i_s);
            int64_t base_    = page_idx * h * s * d;
            if constexpr(Layout == naive_attention_layout_enum::PHSD)
                return static_cast<int64_t>(i_h * s * d + page_offset * d + i_d) + base_;
            else if constexpr(Layout == naive_attention_layout_enum::PHDSX)
            {
                int d_r = i_d / x;
                int d_x = i_d % x;
                return static_cast<int64_t>(i_h * d * s + d_r * s * x + page_offset * x + d_x) + base_;
            }
            else if constexpr(Layout == naive_attention_layout_enum::PHDS)
            {
                return static_cast<int64_t>(i_h * d * s + i_d * s + page_offset) + base_;
            }
        }

        // the set of APIs below directly uses the pointer stored inside this struct
        __device__ void init(int /*i_b*/, int i_h_) { i_h = i_h_; }
        __device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
        __device__ void store(T /*value*/, int /*i_s*/, int /*i_d*/) {}
    };

    template <typename T>
    struct kvscale_addresser
    {
        int h, d; // nhead, hdim
        T* base_ptr;

        __device__ kvscale_addresser(int h_, int d_, void* p_)
            : h(h_), d(d_), base_ptr(reinterpret_cast<T*>(p_))
        {
        }

        __device__ int get_offset(int i_h, int i_d, int i_kv /*0 or 1*/)
        {
            // [h, 2, d]
            return i_h * 2 * d + i_kv * d + i_d;
        }

        __device__ T load(int i_h, int i_d, int i_kv) { return base_ptr[get_offset(i_h, i_d, i_kv)]; }
    };

    __device__ __host__ static constexpr int get_block_size() { return 256; }

    // for simplicity, 1 WG always computes 1 token along q and all tokens along kv;
    // it computes all hdim from q, and WG_SIZE hdim from v
    // 1) in the prefill case, seqlen_q >= 1, seqlen_kv >= 1, batch_q = batch_kv
    // 2) in the decode case, seqlen_q = 1, batch_q is the input num-tokens, batch_kv is 1
    // 3) in the paged-attn case, we still use 1 WG to compute all of seqlen-kv for simplicity
    // TODO: could support split-kv to validate intermediate logsum
    __host__ static dim3 get_grid_size(naive_attention_fwd_args args)
    {
        constexpr int wg_size = get_block_size();
        auto g = dim3((args.hdim_v + wg_size - 1) / wg_size, args.seqlen_q, args.batch_q * args.nhead_q);
        return g;
    }

    // reduce a single pixel within a wave
    template <typename T, typename F>
    __device__ constexpr T wave_reduce(T local, F reduce_f)
    {
        // constexpr int wave_size = 64;
        constexpr int reduce_stage = 6; // 1<<6=64
        T v_local = local;
#pragma unroll
        for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
        {
            int src_lane = __lane_id() ^ (1 << i_stage);
            int32_t v_remote_tmp =
                __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
            T v_remote = bit_cast<T>(v_remote_tmp);
            v_local    = reduce_f(v_local, v_remote);
        }
        return v_local;
    }

    // Note: this function must be called after wave_reduce
    // Note: better not to use this under if...else... with thread divergence (syncthreads)
    template <typename T, typename F>
    __device__ constexpr T cross_wave_reduce(T local, F reduce_f, T* smem)
    {
        constexpr int waves     = 4;
        constexpr int wave_size = 64;
        int lane_id             = threadIdx.x % wave_size;
        __syncthreads();
        smem[threadIdx.x] = local;
        __syncthreads();
        // the data within a single wave is the same,
        // but for simplicity we still read data from each lane.
        T v_local = smem[lane_id];
#pragma unroll
        for(int i_stage = 1; i_stage < waves; i_stage++)
        {
            T v_remote = smem[i_stage * wave_size + lane_id];
            v_local    = reduce_f(v_local, v_remote);
        }
        return v_local;
    }

    // kernel entry point
    __device__ void operator()(naive_attention_fwd_args args)
    {
        constexpr int wg_size = get_block_size();
        __shared__ char smem[wg_size * 4 * sizeof(float)]; // should be enough
        int i_dv    = blockIdx.x * wg_size + threadIdx.x;  // index of hdim_v
        int i_sq    = blockIdx.y;                          // index of seqlen_q
        int i_batch = blockIdx.z;                          // index of batch_q * nhead_q
        int i_bq    = i_batch / args.nhead_q;              // index of batch_q
        int i_hq    = i_batch % args.nhead_q;              // index of nhead_q
        int i_bk    = i_bq / args.batch_ratio_kv;
        int i_hk    = i_hq / args.nhead_ratio_kv;

        void* page_table_ptr = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return reinterpret_cast<int*>(args.page_table_ptr) + i_bq * args.max_pages_per_seq;
            }
            else
            {
                return nullptr;
            }
        }();

        auto q_addr = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
            {
                return addresser<QType, QLayout>{
                    args.batch_q, args.seqlen_q, args.nhead_q, args.hdim, args.q_ptr};
            }
            else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return addresser<QType, QLayout>{
                    args.batch_q, args.seqlen_q, args.nhead_q, args.hdim, args.q_ptr};
            }
        }();

        auto k_addr = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
            {
                return addresser<KType, KLayout>{
                    args.batch_kv, args.seqlen_kv, args.nhead_kv, args.hdim, args.k_ptr};
            }
            else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return page_addresser<KType, KLayout>{
                    args.page_size, args.nhead_kv, args.hdim, args.k_ptr, page_table_ptr};
            }
        }();

        auto v_addr = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
            {
                return addresser<VType, VLayout>{
                    args.batch_kv, args.seqlen_kv, args.nhead_kv, args.hdim_v, args.v_ptr};
            }
            else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return page_addresser<VType, VLayout>{
                    args.page_size, args.nhead_kv, args.hdim_v, args.v_ptr, page_table_ptr};
            }
        }();

        auto o_addr = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
            {
                return addresser<OType, OLayout>{
                    args.batch_q, args.seqlen_q, args.nhead_q, args.hdim_v, args.o_ptr};
            }
            else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return addresser<OType, OLayout>{
                    args.batch_q, args.seqlen_q, args.nhead_q, args.hdim_v, args.o_ptr};
            }
        }();

        q_addr.init(i_bq, i_hq);
        k_addr.init(i_bk, i_hk);
        v_addr.init(i_bk, i_hk);
        o_addr.init(i_bq, i_hq);

        auto f_max = [](auto x_, auto y_) { return max(x_, y_); };
        auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
        auto f_absmax_f32 = [](float v_0_, float v_1_) {
            float rtn;
            asm volatile("v_max_f32 %0, abs(%1), abs(%2)" : "=v"(rtn) : "v"(v_0_), "v"(v_1_));
            return rtn;
        };

        int seqlen_kv = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
            {
                return args.seqlen_kv;
            }
            else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return reinterpret_cast<int*>(args.context_len_ptr)[i_bq];
            }
        }();

        SoftmaxType row_max = -numeric<SoftmaxType>::infinity();
        SoftmaxType l{0};
        AccType o_acc  = {0};
        int sk_loops   = (seqlen_kv + wg_size - 1) / wg_size;
        float qf_scale = .0f;
        kvscale_addresser<KVScaleType> kvscale_addr{args.nhead_kv, args.hdim, args.kvscale_ptr};

        if constexpr(is_kvcache_i8_forward_quant)
        {
            // AccType is i32 now, seqlen_q = 1, hdim up to 256
            float q   = 0;
            float k_s = 0;
            if(static_cast<int>(threadIdx.x) < args.hdim)
            {
                q   = type_convert<float>(q_addr.load(0, threadIdx.x));
                k_s = type_convert<float>(kvscale_addr.load(i_hk, threadIdx.x, 0));
            }
            // 1) we apply the k scale to q
            float q_forwarded = q * k_s;
            // 2) apply smooth-quant
            // find absmax
            float qf_max = wave_reduce(q_forwarded, f_absmax_f32);
            qf_max       = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<float*>(smem));
            // per-token scale
            qf_scale = qf_max / 127.0;
            // divide by scale
            q = q / qf_scale;
            // fp32->i8
            int8_t quantized_q = static_cast<int8_t>(q);
            __syncthreads();
            reinterpret_cast<int8_t*>(smem)[threadIdx.x] = quantized_q;
            __syncthreads();
            // after the above process, we have 2 pieces of data:
            // 1) int8 q data stored in smem (no need to reload)
            // 2) per-token scale qf_scale, to be multiplied after the 1st gemm
        }

        for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++)
        {
            int i_sk = i_loop1 * wg_size + threadIdx.x;
            // gemm-1
            SoftmaxType s_softmax = -numeric<SoftmaxType>::infinity();
            if(i_sk < seqlen_kv)
            {
                AccType s_acc{0}; // clear for every loop
                for(auto i_dq = 0; i_dq < args.hdim; i_dq++)
                {
                    if constexpr(is_kvcache_i8_forward_quant)
                    {
                        int8_t q = reinterpret_cast<int8_t*>(smem)[i_dq];
                        auto k   = k_addr.load(i_sk, i_dq);
                        s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
                    }
                    else
                    {
                        auto q = q_addr.load(i_sq, i_dq); // q will have duplicate loads
                        auto k = k_addr.load(i_sk, i_dq);
                        s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
                    }
                }
                // scale
                s_softmax = type_convert<SoftmaxType>(s_acc);
                s_softmax *= type_convert<SoftmaxType>(args.scale_s * ck_tile::log2e_v<SoftmaxType>);
                if constexpr(is_kvcache_i8_forward_quant)
                {
                    s_softmax *= qf_scale; // post scale the per-token factor
                }
            }
            // s->p
            float pf_scale = 0.; // used for i8 quant
            {
                // softmax, find max
                SoftmaxType old_max = row_max;
                SoftmaxType cur_max = wave_reduce(s_softmax, f_max);
                cur_max = cross_wave_reduce(cur_max, f_max, reinterpret_cast<SoftmaxType*>(smem));
                row_max = max(old_max, cur_max); // update row_max
                // softmax, exp(i_elem - max)
                SoftmaxType p_compute = __builtin_amdgcn_exp2f(s_softmax - row_max);
                // compute exp_sum
                SoftmaxType row_sum = wave_reduce(p_compute, f_sum);
                row_sum = cross_wave_reduce(row_sum, f_sum, reinterpret_cast<SoftmaxType*>(smem));
                // l, pre-scale o_acc
                SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max);
                l     = tmp * l + row_sum;
                o_acc = type_convert<AccType>(type_convert<SoftmaxType>(o_acc) * tmp);
                // store p_compute into smem, so that every thread reads the same p_compute for
                // the 2nd gemm
                if constexpr(is_kvcache_i8_forward_quant)
                {
                    float v_s = 0;
                    if(static_cast<int>(threadIdx.x) < args.hdim_v)
                    {
                        v_s = type_convert<float>(kvscale_addr.load(i_hk, threadIdx.x, 1));
                    }
                    // 1) we apply the v scale to p
                    float p_forwarded = p_compute * v_s;
                    // 2) apply smooth-quant
                    // find absmax
                    float pf_max = wave_reduce(p_forwarded, f_absmax_f32);
                    pf_max = cross_wave_reduce(pf_max, f_absmax_f32, reinterpret_cast<float*>(smem));
                    // per-token scale
                    pf_scale = pf_max / 127.0;
                    // divide by scale
                    p_compute = p_compute / pf_scale;
                    // fp32->i8
                    int8_t quantized_p = static_cast<int8_t>(p_compute);
                    __syncthreads();
                    reinterpret_cast<int8_t*>(smem)[threadIdx.x] = quantized_p;
                    __syncthreads();
                    // after the above process, we have 2 pieces of data:
                    // 1) int8 p data stored in smem (no need to reload)
                    // 2) per-token scale pf_scale, to be multiplied after the 2nd gemm
                }
                else
                {
                    __syncthreads();
                    reinterpret_cast<PType*>(smem)[threadIdx.x] = type_convert<PType>(p_compute);
                    __syncthreads();
                }
            }
            // gemm-2, simple loop over vector by vector
            constexpr int gemm_2_loop = wg_size / p_vec_elem;
            {
                AccType o_acc_local = {0};
                int sk_start        = i_loop1 * wg_size; // we start from the first seqlen_kv element
                for(int i_loop2 = 0; i_loop2 < gemm_2_loop; i_loop2++)
                {
                    p_vec_type p_vec = reinterpret_cast<p_vec_type*>(smem)[i_loop2];
#pragma unroll
                    for(int i_j = 0; i_j < p_vec_elem; i_j++)
                    {
                        int sv_offset = i_loop2 * p_vec_elem + i_j;
                        int i_sv      = sk_start + sv_offset;
                        VType v       = 0.f;
                        if(i_dv < args.hdim_v && i_sv < seqlen_kv)
                        {
                            v = v_addr.load(i_sv, i_dv);
                        }
                        o_acc_local += type_convert<AccType>(p_vec[i_j]) * type_convert<AccType>(v);
                    }
                }
                if constexpr(is_kvcache_i8_forward_quant)
                {
                    // apply the p scale to the local acc
                    o_acc_local = type_convert<AccType>(type_convert<float>(o_acc_local) * pf_scale);
                }
                o_acc += o_acc_local;
            }
        }
        // post scale o_acc
        {
            SoftmaxType tmp = l == 0.f ? 0.f : 1.f / l; // in case of masking
            o_acc = type_convert<AccType>(type_convert<SoftmaxType>(o_acc) * tmp);
        }
        // store O
        if(i_dv < args.hdim_v)
            o_addr.store(type_convert<OType>(o_acc), i_sq, i_dv);
    }
};

#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_()                                                 \
    {                                                                                                \
        using ktraits_ =                                                                             \
            naive_attention_fwd_kernel_traits<static_cast<naive_attention_variation_enum>(          \
                variation_)>;                                                                        \
        using k_ = naive_attention_fwd_kernel<q_type_,                                               \
                                              k_type_,                                               \
                                              v_type_,                                               \
                                              o_type_,                                               \
                                              acc_type_,                                             \
                                              q_layout_,                                             \
                                              k_layout_,                                             \
                                              v_layout_,                                             \
                                              o_layout_,                                             \
                                              ktraits_>;                                             \
        dim3 grids = k_::get_grid_size(a);                                                           \
        r          = ck_tile::launch_kernel(                                                         \
            s, ck_tile::make_kernel(k_{}, grids, k_::get_block_size(), 0, a));                       \
    }

#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_()                                                   \
    if(t.variation == 0 && t.q_layout == "bshd" && t.k_layout == "bshd" && t.v_layout == "bshd" &&   \
       t.o_layout == "bshd")                                                                         \
    {                                                                                                \
        constexpr auto q_layout_ = naive_attention_layout_enum::BSHD;                                \
        constexpr auto k_layout_ = naive_attention_layout_enum::BSHD;                                \
        constexpr auto v_layout_ = naive_attention_layout_enum::BSHD;                                \
        constexpr auto o_layout_ = naive_attention_layout_enum::BSHD;                                \
        constexpr int variation_ = 0;                                                                \
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                                \
    }                                                                                                \
    else if(t.variation == 0 && t.q_layout == "bhsd" && t.k_layout == "bhsd" &&                      \
            t.v_layout == "bhsd" && t.o_layout == "bhsd")                                            \
    {                                                                                                \
        constexpr auto q_layout_ = naive_attention_layout_enum::BHSD;                                \
        constexpr auto k_layout_ = naive_attention_layout_enum::BHSD;                                \
        constexpr auto v_layout_ = naive_attention_layout_enum::BHSD;                                \
        constexpr auto o_layout_ = naive_attention_layout_enum::BHSD;                                \
        constexpr int variation_ = 0;                                                                \
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                                \
    }                                                                                                \
    else if(t.variation == 2 && t.q_layout == "bhsd" && t.k_layout == "phdsx" &&                     \
            t.v_layout == "phds" && t.o_layout == "bhsd")                                            \
    {                                                                                                \
        constexpr auto q_layout_ = naive_attention_layout_enum::BHSD;                                \
        constexpr auto k_layout_ = naive_attention_layout_enum::PHDSX;                               \
        constexpr auto v_layout_ = naive_attention_layout_enum::PHDS;                                \
        constexpr auto o_layout_ = naive_attention_layout_enum::BHSD;                                \
        constexpr int variation_ = 2;                                                                \
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                                \
    }

//
CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
                                       naive_attention_fwd_args a,
                                       ck_tile::stream_config s)
{
    float r = -1;
    // TODO: do not explicitly create too many instances!
    if(t.q_type == "fp16" && t.k_type == "fp16" && t.v_type == "fp16" && t.o_type == "fp16")
    {
        using q_type_   = fp16_t;
        using k_type_   = fp16_t;
        using v_type_   = fp16_t;
        using o_type_   = fp16_t;
        using acc_type_ = float;
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
    }
    else if(t.q_type == "bf16" && t.k_type == "bf16" && t.v_type == "bf16" && t.o_type == "bf16")
    {
        using q_type_   = bf16_t;
        using k_type_   = bf16_t;
        using v_type_   = bf16_t;
        using o_type_   = bf16_t;
        using acc_type_ = float;
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
    }
    else if(t.q_type == "bf16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "bf16")
    {
        using q_type_   = bf16_t;
        using k_type_   = int8_t;
        using v_type_   = int8_t;
        using o_type_   = bf16_t;
        using acc_type_ = int32_t; // NOTE!
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
    }
    else if(t.q_type == "fp16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "fp16")
    {
        using q_type_   = fp16_t;
        using k_type_   = int8_t;
        using v_type_   = int8_t;
        using o_type_   = fp16_t;
        using acc_type_ = int32_t; // NOTE!
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
    }
    return r;
}

#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_
#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_

} // namespace ck_tile
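Note (not part of the commit): a minimal host-side sketch of how the new naive_attention_fwd entry point could be driven, based only on the traits/args/function declared above. The device pointers q_dev/k_dev/v_dev/o_dev, the tensor sizes, and the stream_config initializer are illustrative assumptions, not values taken from this change.

    // Hypothetical usage sketch: fp16 BSHD batched-flash variation (variation 0).
    // Assumes q/k/v/o device buffers of the matching shapes were allocated and filled elsewhere.
    ck_tile::naive_attention_fwd_traits t{};
    t.q_type   = "fp16"; t.k_type   = "fp16"; t.v_type   = "fp16"; t.o_type   = "fp16";
    t.q_layout = "bshd"; t.k_layout = "bshd"; t.v_layout = "bshd"; t.o_layout = "bshd";
    t.variation = 0; // FLASH_BATCHED

    ck_tile::naive_attention_fwd_args a{};
    a.q_ptr = q_dev; a.k_ptr = k_dev; a.v_ptr = v_dev; a.o_ptr = o_dev; // assumed device pointers
    a.context_len_ptr = nullptr; a.page_table_ptr = nullptr; a.kvscale_ptr = nullptr;
    a.scale_s = 1.0f / sqrtf(128.0f); // softmax scale for hdim = 128 (assumed problem size)
    a.hdim = 128; a.hdim_v = 128;
    a.batch_q = 2; a.batch_kv = 2; a.batch_ratio_kv = 1;
    a.seqlen_q = 512; a.seqlen_kv = 512;
    a.nhead_q = 8; a.nhead_kv = 8; a.nhead_ratio_kv = 1;
    a.page_size = 0; a.max_pages_per_seq = 0;

    // returns kernel time as measured by launch_kernel; stream_config fields assumed
    float elapsed_ms = ck_tile::naive_attention_fwd(t, a, ck_tile::stream_config{nullptr, true});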
include/ck_tile/remod.py

@@ -7,6 +7,7 @@ import copy
 NS = 'ck_tile'
 OPS = 'ops'
+REF = 'ref'
 OPS_COMMON = 'common'
 # common header will be duplicated into ops/* other module
 HEADER_COMMON = f"""// SPDX-License-Identifier: MIT

@@ -29,6 +30,9 @@ class submodule_t:
     def push(self, f):
         if len(f.parents) != 1:   # ignore ./xxx.hpp
             mod = get_module(f)
+            # ref is supposed to include one header on demand
+            if mod == REF:
+                return
             if mod == OPS:
                 if mod not in self.m.keys():
                     self.m[mod] = dict()
library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp

@@ -52,6 +52,9 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances =
        DeviceBatchedGemmMultiD_Xdl_CShuffle_V3<Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
        DeviceBatchedGemmMultiD_Xdl_CShuffle_V3<Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
        DeviceBatchedGemmMultiD_Xdl_CShuffle_V3<Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
        DeviceBatchedGemmMultiD_Xdl_CShuffle_V3<Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 160, 64, 8, 8, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
        DeviceBatchedGemmMultiD_Xdl_CShuffle_V3<Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 32, 32, 1, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
        DeviceBatchedGemmMultiD_Xdl_CShuffle_V3<Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 64, 8, 8, 32, 32, 5, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
        DeviceBatchedGemmMultiD_Xdl_CShuffle_V3<Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
        DeviceBatchedGemmMultiD_Xdl_CShuffle_V3<Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
        DeviceBatchedGemmMultiD_Xdl_CShuffle_V3<Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_f8_f8_bf16/device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp

@@ -42,6 +42,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std
 //##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
 //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 #ifdef __gfx94__
        // Compute friendly
        DeviceBatchedGemmMultiD_Xdl_CShuffle_V3<Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,

@@ -72,6 +73,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std:
 //##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
 //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
 //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
        DeviceBatchedGemmMultiD_Xdl_CShuffle_V3<Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
        DeviceBatchedGemmMultiD_Xdl_CShuffle_V3<Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt

@@ -6,7 +6,7 @@ set(CK_TILE_SRC_FOLDER ${CMAKE_SOURCE_DIR}/include/ck_tile/)
 # CK Codegen requires dataclass which is added in Python 3.7
 # Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04
 if(NOT CK_USE_ALTERNATIVE_PYTHON)
-    find_package(PythonInterp 3 REQUIRED)
+    find_package(Python3 COMPONENTS Interpreter Development)
 else()
     message("Using alternative python version")
     set(EXTRA_PYTHON_PATH)

@@ -33,7 +33,7 @@ set(FMHA_KNOWN_APIS "fwd,fwd_splitkv,fwd_appendkv,bwd")
 # Note: The receipt 3 arg filters the generated backwards instances to reduce compilation time.
 # With receipt 3 set, we are generating instances for datatype == {fp16 || bfp16}, bias == {no || alibi}, deterministic == off, and dpad == dvpad.
 execute_process(
-    COMMAND ${PYTHON_EXECUTABLE} ${FMHA_SRC_FOLDER}/generate.py
+    COMMAND ${Python3_EXECUTABLE} ${FMHA_SRC_FOLDER}/generate.py
     --list_blobs ${FMHA_CPP_FOLDER}/blob_list.txt
     --api ${FMHA_KNOWN_APIS}
     --receipt 3

@@ -50,7 +50,7 @@ endif()
 # With receipt 3 set, we are generating instances for datatype == {fp16 || bfp16}, bias == {no || alibi}, deterministic == off, and dpad == dvpad.
 add_custom_command(
     OUTPUT ${FMHA_GEN_BLOBS}
-    COMMAND ${PYTHON_EXECUTABLE} ${FMHA_SRC_FOLDER}/generate.py
+    COMMAND ${Python3_EXECUTABLE} ${FMHA_SRC_FOLDER}/generate.py
     --output_dir ${FMHA_CPP_FOLDER}
     --api ${FMHA_KNOWN_APIS}
     --receipt 3
modified_files.txt (deleted, 100755 → 0; previous content shown at parent d783a8cf)

example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp
example/01_gemm/run_gemm_example_streamk_v2.inc
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_mnpadding_instance.cpp
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp
profiler/src/profile_gemm_universal_streamk.cpp
modified_files.txt
profiler/include/profiler/profile_gemm_universal_batched_impl.hpp

@@ -48,6 +48,7 @@ bool profile_gemm_universal_batched_impl(int do_verification,
                                          int StrideB,
                                          int StrideC,
                                          int BatchCount,
+                                         int KBatch,
                                          int n_warmup,
                                          int n_iter,
                                          uint64_t rotating = 0)

@@ -147,89 +148,100 @@ (restructured: the single per-op run is replaced by a loop over a KBatch list; new state of the hunk)
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;
    float best_kbatch     = 0;

    // profile device op instances
    for(auto& op_ptr : op_ptrs)
    {
        // re-init C to zero before profiling next kernel
        c_device_buf.SetZero();

        std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};

        if(KBatch > 0)
        {
            kbatch_list = {KBatch};
        }

        std::size_t flop = std::size_t(2) * BatchCount * M * N * K;

        for(std::size_t i = 0; i < kbatch_list.size(); i++)
        {
            auto kbatch_curr = kbatch_list[i];

            auto argument_ptr = op_ptr->MakeArgumentPointer(
                static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                {},
                static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                M,
                N,
                K,
                BatchCount,
                StrideA,
                StrideB,
                {},
                StrideC,
                BatchStrideA,
                BatchStrideB,
                {},
                BatchStrideC,
                ck::tensor_operation::element_wise::PassThrough{},
                ck::tensor_operation::element_wise::PassThrough{},
                ck::tensor_operation::element_wise::PassThrough{},
                kbatch_curr);

            auto invoker_ptr = op_ptr->MakeInvokerPointer();

            if(op_ptr->IsSupportedArgument(argument_ptr.get()))
            {
                std::string op_name = op_ptr->GetTypeString();

                std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                                         sizeof(CDataType) * M * N) *
                                        BatchCount;

                float ave_time = invoker_ptr->Run(
                    argument_ptr.get(),
                    StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter, true, rotating_count});

                float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
                float gb_per_sec = num_btype / 1.E6 / ave_time;

                std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                          << " GB/s, " << op_name << ", KBatch " << kbatch_curr << std::endl;

                if(tflops > best_tflops)
                {
                    best_op_name    = op_name;
                    best_tflops     = tflops;
                    best_ave_time   = ave_time;
                    best_gb_per_sec = gb_per_sec;
                    best_kbatch     = kbatch_curr;
                }

                if(do_verification)
                {
                    c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());

                    pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);

                    if(do_log)
                    {
                        LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
                        LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
                        LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
                            << std::endl;
                        LogRangeAsType<float>(std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
                            << std::endl;
                    }
                }
            }
            else
            {
                std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
            }
        }
    }

@@ -270,8 +282,8 @@ bool profile_gemm_universal_batched_impl(int do_verification,
    std::cout << " B = " << BatchCount << " M = " << M << " N = " << N << " K = " << K
              << " StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC
-             << ": " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
-             << " GB/s, " << best_op_name << std::endl;
+             << " KBatch = " << best_kbatch << ": " << best_ave_time << " ms, " << best_tflops
+             << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;

    return pass;
}
profiler/src/profile_gemm_universal_batched.cpp

@@ -31,7 +31,7 @@ enum struct GemmDataType
 int profile_batched_gemm_universal(int argc, char* argv[])
 {
-    if(argc != 18 && argc != 21)
+    if(argc != 19 && argc != 22)
     {
         // clang-format off
         printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");

@@ -44,11 +44,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
         printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
         printf("arg6: print tensor value (0: no; 1: yes)\n");
         printf("arg7: time kernel (0=n0, 1=yes)\n");
-        printf("arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount\n");
+        printf("arg8 to 18: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount, KBatch\n");
         printf("optional:\n");
-        printf("arg18: number of warm-up cycles (default 1)\n");
-        printf("arg19: number of iterations (default 10)\n");
-        printf("arg20: memory for rotating buffer (default 0, size in MB)\n");
+        printf("arg19: number of warm-up cycles (default 1)\n");
+        printf("arg20: number of iterations (default 10)\n");
+        printf("arg21: memory for rotating buffer (default 0, size in MB)\n");
         // clang-format on
         exit(1);
     }

@@ -56,11 +56,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
     int n_warmup      = 1;
     int n_iter        = 10;
     uint64_t rotating = 0;
-    if(argc == 21)
+    if(argc == 22)
     {
-        n_warmup = std::stoi(argv[18]);
-        n_iter   = std::stoi(argv[19]);
-        rotating = std::stoull(argv[20]) * 1024 * 1024;
+        n_warmup = std::stoi(argv[19]);
+        n_iter   = std::stoi(argv[20]);
+        rotating = std::stoull(argv[21]) * 1024 * 1024;
     }

     const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));

@@ -83,6 +83,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
     const int BatchStrideC = std::stoi(argv[16]);
     const int BatchCount   = std::stoi(argv[17]);
+    const int KBatch       = std::stoi(argv[18]);

 #if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
     using F8 = ck::f8_t;

@@ -159,6 +160,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
                                                  StrideB_,
                                                  StrideC_,
                                                  BatchCount,
+                                                 KBatch,
                                                  n_warmup,
                                                  n_iter,
                                                  rotating);
script/process_perf_data.py

@@ -82,7 +82,7 @@ def parse_logfile(logfile):
     StrideA = []
     StrideB = []
     StrideC = []
-    if 'perf_gemm.log' in logfile:
+    if 'perf_gemm' in logfile and 'gemm_bilinear' not in logfile:
         for line in open(logfile):
             if 'Best Perf' in line:
                 lst = line.split()

@@ -260,7 +260,7 @@ def main():
     conn = sqlEngine.connect()
     #save gemm performance tests:
-    if 'perf_gemm.log' in filename:
+    if 'perf_gemm' in filename and 'gemm_bilinear' not in filename:
         #write the ck_gemm_test_params table only needed once the test set changes
         #post_test_params(test_list,conn)
         for i in range(1, len(results)+1):

@@ -332,7 +332,7 @@ def main():
             table_name = "ck_fmha_bwd_tflops"
             tflops_base = get_baseline(table_name, conn)
-            store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, conn)
+            store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, sqlEngine)
     conn.close()
     #compare the results to the baseline if baseline exists
script/process_perf_data.sh

@@ -11,9 +11,22 @@
 #process results
 python3 process_perf_data.py perf_gemm.log
 python3 process_perf_data.py perf_onnx_gemm.log
 python3 process_perf_data.py perf_resnet50_N256.log
 python3 process_perf_data.py perf_resnet50_N4.log
+file=./perf_onnx_gemm_gfx10.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_onnx_gemm_gfx10.log
+fi
+file=./perf_onnx_gemm_gfx11.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_onnx_gemm_gfx11.log
+fi
+file=./perf_onnx_gemm_gfx12.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_onnx_gemm_gfx12.log
+fi
 file=./perf_fmha_fwd_gfx942.log
 if [ -e "$file" ]; then
     python3 process_perf_data.py perf_fmha_fwd_gfx942.log
script/process_qa_data.sh

@@ -24,6 +24,18 @@
 python3 process_perf_data.py perf_splitK_gemm.log
 python3 process_perf_data.py perf_onnx_gemm.log
 python3 process_perf_data.py perf_mixed_gemm.log
+file=./perf_onnx_gemm_gfx10.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_onnx_gemm_gfx10.log
+fi
+file=./perf_onnx_gemm_gfx11.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_onnx_gemm_gfx11.log
+fi
+file=./perf_onnx_gemm_gfx12.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_onnx_gemm_gfx12.log
+fi
 file=./perf_fmha_fwd_gfx942.log
 if [ -e "$file" ]; then
     python3 process_perf_data.py perf_fmha_fwd_gfx942.log
script/run_full_performance_tests.sh

@@ -5,7 +5,7 @@
 # post your new test results to the database and compare them to the baseline
 # please contact Illia.Silin@amd.com for more details
 #
-# run the script as "./run_full_performance_tests.sh <verification> <tag for your test environment> <branch name> <
-node name>
+# run the script as "./run_full_performance_tests.sh <verification> <tag for your test environment> <branch name> <node name>
 # input arguments:
 # verification = 0 : do not verify result correctness on CPU
 #              = 1 : verifuy correctness on CPU (may take a long time)
script/run_gemm_performance_tests.sh (new file, 0 → 100755)

#!/bin/bash
#
# in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/
# run the script as "./run_gemm_performance_tests.sh <verification> <tag for your test environment> <branch name> <node name> <arch>
# input arguments:
# verification = 0 : do not verify result correctness on CPU
#              = 1 : verify correctness on CPU (may take a long time)
# environment tag  : a string describing the specifics of your test environment
# branch name      : name of the branch in git repo (git status | grep -e 'On branch')
# node name        : $hostname
# arch             : GPU architecture, e.g. "gfx9" or "gfx1100"

#get the command line arguments:
export verify=$1
echo 'Verification: ' $verify
export env_type=$2
echo 'Environment type: ' $env_type
export branch=$3
echo 'Branch name: ' $branch
export host_name=$4
echo 'Host name: ' $host_name
export arch=$5
echo 'GPU architecture: ' $arch

function print_log_header(){
	rm -f $1;
	echo 'On branch ' $3 &> $1;
	echo 'Node name: ' $4 >> $1;
	#get GPU_arch and number of compute units from rocminfo
	echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1;
	rocminfo | grep "Compute Unit:" >> $1;
	hipcc --version | grep -e 'HIP version' >> $1;
	echo 'Environment type: ' $2 >> $1;
	/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1;
}

#run ONNX gemm tests
export onnx_log="perf_onnx_gemm_$arch.log"
print_log_header $onnx_log $env_type $branch $host_name
./profile_onnx_gemm.sh gemm 0 0 $verify 1 0 1 2>&1 | tee -a $onnx_log
./profile_onnx_gemm.sh gemm 1 0 $verify 1 0 1 2>&1 | tee -a $onnx_log
script/run_performance_tests.sh

@@ (header comment)
 #!/bin/bash
 #
 # in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/
-# run the script as "./run_performance_tests.sh <verification> <tag for your test environment> <branch name> <
-node name>
+# run the script as "./run_performance_tests.sh <verification> <tag for your test environment> <branch name> <node name>
 # input arguments:
 # verification = 0 : do not verify result correctness on CPU
 #              = 1 : verify correctness on CPU (may take a long time)

@@ -51,20 +51,11 @@ print_log_header $gemm_log $env_type $branch $host_name
 ./profile_gemm.sh gemm 2 3 $verify 1 0 1 | tee -a $gemm_log
 ./profile_gemm.sh gemm 3 3 $verify 1 0 1 | tee -a $gemm_log

-#run grouped_fwd fp16 tests
-export grouped_conv_fwd_log="perf_grouped_conv_fwd_fp16.log"
-print_log_header $conv_fwd_log $env_type $branch $host_name
-./profile_grouped_conv_fwd.sh grouped_conv_fwd 1 1 0 $verify 1 0 1 256 2>&1 | tee -a $grouped_conv_fwd_log
-#run grouped_bwd_data fp16 tests
-export grouped_conv_bwd_data_log="perf_grouped_conv_bwd_data_fp16.log"
-print_log_header $grouped_conv_bwd_data_log $env_type $branch $host_name
-./profile_grouped_conv_bwd_data.sh grouped_conv_bwd_data 1 1 $verify 1 0 1 256 2>&1 | tee -a $grouped_conv_bwd_data_log
-#run grouped_bwd_weight fp16 tests
-export grouped_conv_bwd_weight_log="perf_grouped_conv_bwd_weight_fp16.log"
-print_log_header $grouped_conv_bwd_weight_log $env_type $branch $host_name
-./profile_grouped_conv_bwd_weight.sh grouped_conv_bwd_weight 1 1 $verify 1 0 1 256 1 2>&1 | tee -a $grouped_conv_bwd_weight_log
 #run ONNX gemm tests
 export onnx_log="perf_onnx_gemm.log"
 print_log_header $onnx_log $env_type $branch $host_name
 ./profile_onnx_gemm.sh gemm 0 0 $verify 1 0 1 2>&1 | tee -a $onnx_log
 ./profile_onnx_gemm.sh gemm 1 0 $verify 1 0 1 2>&1 | tee -a $onnx_log
 #run resnet50 tests
 export resnet256_log="perf_resnet50_N256.log"
test/ck_tile/gemm/test_gemm_mem_pipeline.cpp

@@ -8,35 +8,29 @@
 #include "ck_tile/host.hpp"
 #include "test_gemm_mem_pipeline_util.hpp"

 using F16 = ck_tile::half_t;
 using F32 = float;
 using Row = ck_tile::tensor_layout::gemm::RowMajor;
 using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

-static constexpr auto Intrawave = ck_tile::GemmPipelineScheduler::Intrawave;
-static constexpr auto Interwave = ck_tile::GemmPipelineScheduler::Interwave;
-
-template <typename Tuple>
-class TestCkTileGemmMemPipelineIntrawave : public TestCkTileGemmMemPipeline<Tuple, Intrawave>
-{
-};
-
-template <typename Tuple>
-class TestCkTileGemmMemPipelineInterwave : public TestCkTileGemmMemPipeline<Tuple, Interwave>
-{
-};
+using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
+                                             ck_tile::GemmPipelineScheduler::Intrawave>;
+using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
+                                             ck_tile::GemmPipelineScheduler::Interwave>;

 // clang-format off
 using KernelTypes = ::testing::Types<
-    // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
-    std::tuple<Row, Col, Row, F16, F16, F32, F16>,
-    std::tuple<Col, Row, Row, F16, F16, F32, F16>,
-    std::tuple<Row, Row, Row, F16, F16, F32, F16>,
-    std::tuple<Col, Col, Row, F16, F16, F32, F16>
+    // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler
+    std::tuple<Row, Row, Row, F16, F16, F32, F16, Intrawave>,
+    std::tuple<Row, Row, Row, F16, F16, F32, F16, Interwave>,
+    std::tuple<Row, Col, Row, F16, F16, F32, F16, Intrawave>,
+    std::tuple<Row, Col, Row, F16, F16, F32, F16, Interwave>,
+    std::tuple<Col, Row, Row, F16, F16, F32, F16, Intrawave>,
+    std::tuple<Col, Row, Row, F16, F16, F32, F16, Interwave>,
+    std::tuple<Col, Col, Row, F16, F16, F32, F16, Intrawave>,
+    std::tuple<Col, Col, Row, F16, F16, F32, F16, Interwave>
     >;
 // clang-format on

-TYPED_TEST_SUITE(TestCkTileGemmMemPipelineIntrawave, KernelTypes);
-TYPED_TEST_SUITE(TestCkTileGemmMemPipelineInterwave, KernelTypes);
+TYPED_TEST_SUITE(TestCkTileGemmMemPipeline, KernelTypes);

 #include "test_gemm_mem_pipeline_ut_cases.inc"
test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc

@@ -3,11 +3,7 @@
 #pragma once

-//------------------------------------------------------------------------------------------------
-// INTERWAVE SCHEDULER
-//------------------------------------------------------------------------------------------------
-TYPED_TEST(TestCkTileGemmMemPipelineInterwave, SmallM)
+TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
 {
     std::vector<int> Ms{1, 2, 3, 4, 5, 6};
     constexpr int N = 1024;

@@ -17,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, SmallM)
         this->Run(M, N, K);
 }

-TYPED_TEST(TestCkTileGemmMemPipelineInterwave, MidLargeM)
+TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
 {
     std::vector<int> Ms{127, 255, 312, 799, 1573};
     constexpr int N = 1024;

@@ -27,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, MidLargeM)
         this->Run(M, N, K);
 }

-TYPED_TEST(TestCkTileGemmMemPipelineInterwave, PaddK)
+TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
 {
     std::vector<int> Ms{127};
     constexpr int N = 1024;

@@ -37,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, PaddK)
         this->Run(M, N, K);
 }

-TYPED_TEST(TestCkTileGemmMemPipelineInterwave, Regular)
+TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
 {
     std::vector<int> Ms{512};
     constexpr int N = 1024;

@@ -47,46 +43,15 @@ TYPED_TEST(TestCkTileGemmMemPipelineInterwave, Regular)
         this->Run(M, N, K);
 }

-//------------------------------------------------------------------------------------------------
-// INTRAWAVE SCHEDULER
-//------------------------------------------------------------------------------------------------
-TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, SmallM)
+TYPED_TEST(TestCkTileGemmMemPipeline, NotSupportedArgument)
 {
-    std::vector<int> Ms{1, 2, 3, 4, 5, 6};
-    constexpr int N = 1024;
-    constexpr int K = 320;
-
-    for(int M : Ms)
-        this->Run(M, N, K);
-}
+    constexpr int M = 512;
+    constexpr int N = 1025;
+    constexpr int K = 513;

-TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, MidLargeM)
-{
-    std::vector<int> Ms{127, 255, 312, 799, 1573};
-    constexpr int N = 1024;
-    constexpr int K = 320;
-
-    for(int M : Ms)
-        this->Run(M, N, K);
-}
+    constexpr bool PadM = false;
+    constexpr bool PadN = false;
+    constexpr bool PadK = false;

-TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, PaddK)
-{
-    std::vector<int> Ms{127};
-    constexpr int N = 1024;
-    constexpr int K = 432;
-
-    for(int M : Ms)
-        this->Run(M, N, K);
-}
-
-TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, Regular)
-{
-    std::vector<int> Ms{512};
-    constexpr int N = 1024;
-    constexpr int K = 512;
-
-    for(int M : Ms)
-        this->Run(M, N, K);
-}
+    EXPECT_THROW((this->template Run<PadM, PadN, PadK>(M, N, K)), std::runtime_error);
+}
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
View file @
401e643e
...
@@ -11,7 +11,7 @@
 #include "ck_tile/ops/epilogue.hpp"
 #include "ck_tile/ops/gemm.hpp"

-template <typename Tuple, ck_tile::GemmPipelineScheduler Scheduler_>
+template <typename Tuple>
 class TestCkTileGemmMemPipeline : public ::testing::Test
 {
     protected:
...
@@ -22,7 +22,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
     using BDataType   = std::tuple_element_t<4, Tuple>;
     using AccDataType = std::tuple_element_t<5, Tuple>;
     using CDataType   = std::tuple_element_t<6, Tuple>;
-    static constexpr auto Scheduler = Scheduler_;
+    static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;

     // TODO: expose tile size through test t-param ?
     struct gemm_args
...
@@ -39,6 +39,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
         ck_tile::index_t stride_C;
     };

+    template <bool PadM, bool PadN, bool PadK>
     void invoke_gemm(const gemm_args& args, const ck_tile::stream_config& s)
     {
         // TODO: This should be parameterized in tests
...
@@ -54,9 +55,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
         constexpr ck_tile::index_t N_Warp_Tile = 32;
         constexpr ck_tile::index_t K_Warp_Tile = 8;

-        constexpr bool kPadM = true;
-        constexpr bool kPadN = true;
-        constexpr bool kPadK = true;
+        constexpr bool kPadM = PadM;
+        constexpr bool kPadN = PadN;
+        constexpr bool kPadK = PadK;

         constexpr int kBlockPerCu = 1;
...
@@ -107,6 +108,11 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
         const dim3 grids      = Kernel::GridSize(args.M, args.N, args.kbatch);
         constexpr dim3 blocks = Kernel::BlockSize();

+        if(!Kernel::IsSupportedArgument(kargs))
+        {
+            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
+        }
+
         if(s.log_level_ > 0)
         {
             std::cout << "Launching kernel with args:"
...
@@ -212,6 +218,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
     void SetUp() override { k_batches_ = {1}; }

+    template <bool PadM = true, bool PadN = true, bool PadK = true>
     void Run(const int M,
              const int N,
              const int K,
...
@@ -221,10 +228,11 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
     {
         for(auto kb : k_batches_)
         {
-            RunSingle(M, N, K, StrideA, StrideB, StrideC, kb);
+            RunSingle<PadM, PadN, PadK>(M, N, K, StrideA, StrideB, StrideC, kb);
         }
     }

+    template <bool PadM, bool PadN, bool PadK>
     void RunSingle(const int M,
                    const int N,
                    const int K,
...
@@ -301,7 +309,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
         args.stride_B = stride_B;
         args.stride_C = stride_C;

-        invoke_gemm(args, ck_tile::stream_config{nullptr, false});
+        invoke_gemm<PadM, PadN, PadK>(args, ck_tile::stream_config{nullptr, false});

         c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
         bool pass = true;
...
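With this change the scheduler is no longer a separate non-type template parameter of the fixture; it travels inside the test's type tuple (element 7), so a single TYPED_TEST_SUITE can instantiate both Interwave and Intrawave variants. A minimal sketch of that pattern (simplified tuple layout, not the real fixture) is:

#include <iostream>
#include <tuple>
#include <type_traits>

enum class GemmPipelineScheduler { Intrawave, Interwave };

// Wrap the enum value in a type so it can be carried inside a std::tuple of types.
template <GemmPipelineScheduler S>
using SchedulerTag = std::integral_constant<GemmPipelineScheduler, S>;

template <typename Tuple>
struct Fixture
{
    // Mirrors `Scheduler = std::tuple_element_t<7, Tuple>::value` in the real fixture;
    // the tag sits at index 0 here only to keep the sketch short.
    static constexpr auto Scheduler = std::tuple_element_t<0, Tuple>::value;
};

int main()
{
    using Interwave = std::tuple<SchedulerTag<GemmPipelineScheduler::Interwave>>;
    using Intrawave = std::tuple<SchedulerTag<GemmPipelineScheduler::Intrawave>>;
    std::cout << (Fixture<Interwave>::Scheduler == GemmPipelineScheduler::Interwave) << '\n'; // 1
    std::cout << (Fixture<Intrawave>::Scheduler == GemmPipelineScheduler::Interwave) << '\n'; // 0
}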
test/data_type/test_custom_type.cpp View file @ 401e643e
...
@@ -51,8 +51,11 @@ TEST(Custom_bool, TestAsType)
     ck::static_for<0, size, 1>{}([&](auto i) {
         right_vec.template AsType<custom_bool_t>()(Number<i>{}) = custom_bool_t{test_vec.at(i)};
     });
     // copy the vector
-    vector_type<custom_bool_t, size> left_vec{right_vec};
+    vector_type<custom_bool_t, size> left_vec;
+    // check copy assignment op
+    left_vec = right_vec;
+    // overwrite right_vec with 0s
+    right_vec = vector_type<custom_bool_t, size>{};
     // check if values were copied correctly
     ck::static_for<0, size, 1>{}([&](auto i) {
         ASSERT_EQ(left_vec.template AsType<custom_bool_t>()(Number<i>{}).data, test_vec.at(i));
...
@@ -129,8 +132,11 @@ TEST(Custom_int8, TestAsType)
     ck::static_for<0, size, 1>{}([&](auto i) {
         right_vec.template AsType<custom_int8_t>()(Number<i>{}) = custom_int8_t{test_vec.at(i)};
     });
     // copy the vector
-    vector_type<custom_int8_t, size> left_vec{right_vec};
+    vector_type<custom_int8_t, size> left_vec;
+    // check copy assignment op
+    left_vec = right_vec;
+    // overwrite right_vec with 0s
+    right_vec = vector_type<custom_int8_t, size>{};
     // check if values were copied correctly
     ck::static_for<0, size, 1>{}([&](auto i) {
         ASSERT_EQ(left_vec.template AsType<custom_int8_t>()(Number<i>{}).data, test_vec.at(i));
...
@@ -207,8 +213,11 @@ TEST(Custom_uint8, TestAsType)
     ck::static_for<0, size, 1>{}([&](auto i) {
         right_vec.template AsType<custom_uint8_t>()(Number<i>{}) = custom_uint8_t{test_vec.at(i)};
     });
     // copy the vector
-    vector_type<custom_uint8_t, size> left_vec{right_vec};
+    vector_type<custom_uint8_t, size> left_vec;
+    // check copy assignment op
+    left_vec = right_vec;
+    // overwrite right_vec with 0s
+    right_vec = vector_type<custom_uint8_t, size>{};
     // check if values were copied correctly
     ck::static_for<0, size, 1>{}([&](auto i) {
         ASSERT_EQ(left_vec.template AsType<custom_uint8_t>()(Number<i>{}).data, test_vec.at(i));
...
@@ -287,8 +296,11 @@ TEST(Custom_f8, TestAsType)
     ck::static_for<0, size, 1>{}([&](auto i) {
         right_vec.template AsType<custom_f8_t>()(Number<i>{}) = custom_f8_t{test_vec.at(i)};
     });
     // copy the vector
-    vector_type<custom_f8_t, size> left_vec{right_vec};
+    vector_type<custom_f8_t, size> left_vec;
+    // check copy assignment op
+    left_vec = right_vec;
+    // overwrite right_vec with 0s
+    right_vec = vector_type<custom_f8_t, size>{};
     // check if values were copied correctly
     ck::static_for<0, size, 1>{}([&](auto i) {
         ASSERT_EQ(left_vec.template AsType<custom_f8_t>()(Number<i>{}).data, test_vec.at(i));
...
@@ -369,8 +381,11 @@ TEST(Custom_bf8, TestAsType)
     ck::static_for<0, size, 1>{}([&](auto i) {
         right_vec.template AsType<custom_bf8_t>()(Number<i>{}) = custom_bf8_t{test_vec.at(i)};
     });
     // copy the vector
-    vector_type<custom_bf8_t, size> left_vec{right_vec};
+    vector_type<custom_bf8_t, size> left_vec;
+    // check copy assignment op
+    left_vec = right_vec;
+    // overwrite right_vec with 0s
+    right_vec = vector_type<custom_bf8_t, size>{};
     // check if values were copied correctly
     ck::static_for<0, size, 1>{}([&](auto i) {
         ASSERT_EQ(left_vec.template AsType<custom_bf8_t>()(Number<i>{}).data, test_vec.at(i));
...
@@ -450,8 +465,11 @@ TEST(Custom_half, TestAsType)
     ck::static_for<0, size, 1>{}([&](auto i) {
         right_vec.template AsType<custom_half_t>()(Number<i>{}) = custom_half_t{test_vec.at(i)};
     });
     // copy the vector
-    vector_type<custom_half_t, size> left_vec{right_vec};
+    vector_type<custom_half_t, size> left_vec;
+    // check copy assignment op
+    left_vec = right_vec;
+    // overwrite right_vec with 0s
+    right_vec = vector_type<custom_half_t, size>{};
     // check if values were copied correctly
     ck::static_for<0, size, 1>{}([&](auto i) {
         ASSERT_EQ(left_vec.template AsType<custom_half_t>()(Number<i>{}).data, test_vec.at(i));
...
@@ -533,8 +551,11 @@ TEST(Custom_bhalf, TestAsType)
     ck::static_for<0, size, 1>{}([&](auto i) {
         right_vec.template AsType<custom_bhalf_t>()(Number<i>{}) = custom_bhalf_t{test_vec.at(i)};
     });
     // copy the vector
-    vector_type<custom_bhalf_t, size> left_vec{right_vec};
+    vector_type<custom_bhalf_t, size> left_vec;
+    // check copy assignment op
+    left_vec = right_vec;
+    // overwrite right_vec with 0s
+    right_vec = vector_type<custom_bhalf_t, size>{};
     // check if values were copied correctly
     ck::static_for<0, size, 1>{}([&](auto i) {
         ASSERT_EQ(left_vec.template AsType<custom_bhalf_t>()(Number<i>{}).data, test_vec.at(i));
...
@@ -615,8 +636,11 @@ TEST(Custom_float, TestAsType)
     ck::static_for<0, size, 1>{}([&](auto i) {
         right_vec.template AsType<custom_float_t>()(Number<i>{}) = custom_float_t{test_vec.at(i)};
     });
     // copy the vector
-    vector_type<custom_float_t, size> left_vec{right_vec};
+    vector_type<custom_float_t, size> left_vec;
+    // check copy assignment op
+    left_vec = right_vec;
+    // overwrite right_vec with 0s
+    right_vec = vector_type<custom_float_t, size>{};
     // check if values were copied correctly
     ck::static_for<0, size, 1>{}([&](auto i) {
         ASSERT_EQ(left_vec.template AsType<custom_float_t>()(Number<i>{}).data, test_vec.at(i));
...
@@ -693,8 +717,11 @@ TEST(Custom_double, TestAsType)
     ck::static_for<0, size, 1>{}([&](auto i) {
         right_vec.template AsType<custom_double_t>()(Number<i>{}) = custom_double_t{test_vec.at(i)};
     });
     // copy the vector
-    vector_type<custom_double_t, size> left_vec{right_vec};
+    vector_type<custom_double_t, size> left_vec;
+    // check copy assignment op
+    left_vec = right_vec;
+    // overwrite right_vec with 0s
+    right_vec = vector_type<custom_double_t, size>{};
     // check if values were copied correctly
     ck::static_for<0, size, 1>{}([&](auto i) {
         ASSERT_EQ(left_vec.template AsType<custom_double_t>()(Number<i>{}).data, test_vec.at(i));
...
@@ -813,8 +840,11 @@ TEST(Complex_half, TestAsType)
         right_vec.template AsType<complex_half_t>()(Number<i>{}) =
             complex_half_t{test_vec.at(num_elem * i), test_vec.at(num_elem * i + 1)};
     });
     // copy the vector
-    vector_type<complex_half_t, size> left_vec{right_vec};
+    vector_type<complex_half_t, size> left_vec;
+    // check copy assignment op
+    left_vec = right_vec;
+    // overwrite right_vec with 0s
+    right_vec = vector_type<complex_half_t, size>{};
     // check if values were copied correctly
     ck::static_for<0, size, 1>{}([&](auto i) {
         ASSERT_EQ(left_vec.template AsType<complex_half_t>()(Number<i>{}).real,
...
@@ -907,8 +937,11 @@ TEST(FP8OCP, TestAsType)
         right_vec.template AsType<f8_t>()(Number<i>{}) = ck::type_convert<f8_t>(test_vec.at(i));
     });
     // copy the vector
-    vector_type<f8_t, size> left_vec{right_vec};
+    vector_type<f8_t, size> left_vec;
+    // check copy assignment op
+    left_vec = right_vec;
+    // overwrite right_vec with 0s
+    right_vec = vector_type<f8_t, size>{};
     // check if values were copied correctly
     ck::static_for<0, size, 1>{}([&](auto i) {
...
@@ -984,8 +1017,11 @@ TEST(BF8OCP, TestAsType)
         right_vec.template AsType<bf8_t>()(Number<i>{}) = ck::type_convert<bf8_t>(test_vec.at(i));
     });
     // copy the vector
-    vector_type<bf8_t, size> left_vec{right_vec};
+    vector_type<bf8_t, size> left_vec;
+    // check copy assignment op
+    left_vec = right_vec;
+    // overwrite right_vec with 0s
+    right_vec = vector_type<bf8_t, size>{};
     // check if values were copied correctly
     ck::static_for<0, size, 1>{}([&](auto i) {
...
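The repeated edit in these tests swaps copy construction for an explicit default-construction, copy-assignment, and source-overwrite sequence, so the assignment operator itself is what gets exercised and the destination is checked against the original data rather than the (now zeroed) source. A minimal stand-alone sketch of that test pattern (using std::array instead of ck::vector_type, for illustration only) is:

#include <array>
#include <cassert>

int main()
{
    std::array<float, 4> right_vec{1.f, 2.f, 3.f, 4.f};
    std::array<float, 4> left_vec; // default-constructed, then assigned
    left_vec = right_vec;          // exercise the copy-assignment operator
    right_vec = {};                // overwrite the source with zeros
    for(std::size_t i = 0; i < left_vec.size(); ++i)
        assert(left_vec[i] == static_cast<float>(i + 1)); // the copy must be unaffected
}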
test/grouped_convnd_bwd_data/CMakeLists.txt View file @ 401e643e
-add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_xdl_wmma.cpp)
+add_gtest_executable(test_grouped_convnd_bwd_data_xdl test_grouped_convnd_bwd_data_xdl.cpp)
 if(result EQUAL 0)
-    target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
+    target_link_libraries(test_grouped_convnd_bwd_data_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
 endif()
+add_gtest_executable(test_grouped_convnd_bwd_data_wmma test_grouped_convnd_bwd_data_wmma.cpp)
+if(result EQUAL 0)
+    target_link_libraries(test_grouped_convnd_bwd_data_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
+endif()
 add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp)
 if(result EQUAL 0)
...
test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_wmma.cpp 0 → 100644 View file @ 401e643e
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <tuple>
#include <vector>

#include <gtest/gtest.h>

#include "profiler/profile_grouped_conv_bwd_data_impl.hpp"

template <typename Tuple>
class TestGroupedConvndBwdDataWmma : public ::testing::Test
{
    protected:
    using DataType  = std::tuple_element_t<0, Tuple>;
    using OutLayout = std::tuple_element_t<1, Tuple>;
    using WeiLayout = std::tuple_element_t<2, Tuple>;
    using InLayout  = std::tuple_element_t<3, Tuple>;
    std::vector<ck::utils::conv::ConvParam> conv_params;

    template <ck::index_t NDimSpatial>
    void Run()
    {
        EXPECT_FALSE(conv_params.empty());
        bool pass = true;
        for(auto& param : conv_params)
        {
            pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl<NDimSpatial,
                                                                            OutLayout,
                                                                            WeiLayout,
                                                                            InLayout,
                                                                            DataType,
                                                                            DataType,
                                                                            DataType>(
                               true,  // do_verification
                               1,     // init_method: integer value
                               false, // do_log
                               false, // time_kernel
                               param);
        }
        EXPECT_TRUE(pass);
    }
};

using namespace ck::tensor_layout::convolution;

using KernelTypes2d = ::testing::Types<std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>,
                                       std::tuple<int8_t, GNHWK, GKYXC, GNHWC>,
                                       std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
                                       std::tuple<int8_t, NHWGK, GKYXC, NHWGC>>;

using KernelTypes3d = ::testing::Types<std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>,
                                       std::tuple<int8_t, GNDHWK, GKZYXC, GNDHWC>,
                                       std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>,
                                       std::tuple<int8_t, NDHWGK, GKZYXC, NDHWGC>>;

template <typename Tuple>
class TestGroupedConvndBwdDataWmma2d : public TestGroupedConvndBwdDataWmma<Tuple>
{
};

template <typename Tuple>
class TestGroupedConvndBwdDataWmma3d : public TestGroupedConvndBwdDataWmma<Tuple>
{
};

TYPED_TEST_SUITE(TestGroupedConvndBwdDataWmma2d, KernelTypes2d);
TYPED_TEST_SUITE(TestGroupedConvndBwdDataWmma3d, KernelTypes3d);

TYPED_TEST(TestGroupedConvndBwdDataWmma2d, Test2D)
{
    this->conv_params.clear();
    this->conv_params.push_back(
        {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
    this->conv_params.push_back(
        {2, 2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
    this->conv_params.push_back(
        {2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
    this->conv_params.push_back(
        {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
    this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
    this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
    this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
    this->template Run<2>();
}

TYPED_TEST(TestGroupedConvndBwdDataWmma3d, Test3D)
{
    this->conv_params.clear();
    this->conv_params.push_back(
        {3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
    this->conv_params.push_back(
        {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
    this->conv_params.push_back(
        {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
    this->conv_params.push_back(
        {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
    this->conv_params.push_back(
        {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
    this->conv_params.push_back(
        {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
    this->template Run<3>();
}