Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9533a172
Unverified
Commit
9533a172
authored
Dec 02, 2024
by
Illia Silin
Committed by
GitHub
Dec 02, 2024
Browse files
Merge branch 'develop' into codegen-enable-hiprtc
parents
c2cf0733
50ee4267
Changes
503
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
623 additions
and
276 deletions
+623
-276
profiler/src/profile_grouped_gemm_two_stage.cpp
profiler/src/profile_grouped_gemm_two_stage.cpp
+0
-228
profiler/src/profile_layernorm_fwd.cpp
profiler/src/profile_layernorm_fwd.cpp
+1
-1
python/ck4inductor/batched_universal_gemm/gen_instances.py
python/ck4inductor/batched_universal_gemm/gen_instances.py
+149
-0
python/ck4inductor/batched_universal_gemm/op.py
python/ck4inductor/batched_universal_gemm/op.py
+99
-0
python/ck4inductor/grouped_conv_fwd/gen_instances.py
python/ck4inductor/grouped_conv_fwd/gen_instances.py
+1
-3
script/process_perf_data.py
script/process_perf_data.py
+2
-2
script/process_qa_data.sh
script/process_qa_data.sh
+1
-0
test/CMakeLists.txt
test/CMakeLists.txt
+6
-6
test/ck_tile/CMakeLists.txt
test/ck_tile/CMakeLists.txt
+1
-0
test/ck_tile/batched_gemm/CMakeLists.txt
test/ck_tile/batched_gemm/CMakeLists.txt
+4
-0
test/ck_tile/batched_gemm/test_batched_gemm.cpp
test/ck_tile/batched_gemm/test_batched_gemm.cpp
+29
-0
test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc
test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc
+9
-0
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
+225
-0
test/ck_tile/gemm/test_gemm_mem_pipeline.cpp
test/ck_tile/gemm/test_gemm_mem_pipeline.cpp
+16
-3
test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc
test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc
+55
-4
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
+19
-18
test/gemm_universal/test_gemm_universal_xdl.cpp
test/gemm_universal/test_gemm_universal_xdl.cpp
+2
-2
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
+2
-2
test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases_xdl.cpp
...ed_convnd_fwd/test_grouped_convnd_fwd_large_cases_xdl.cpp
+2
-1
test/grouped_gemm/CMakeLists.txt
test/grouped_gemm/CMakeLists.txt
+0
-6
No files found.
profiler/src/profile_grouped_gemm_two_stage.cpp
deleted
100644 → 0
View file @
c2cf0733
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
#include "profiler_operation_registry.hpp"
enum
struct
GemmMatrixLayout
{
MK_KN_MN
,
// 0
MK_NK_MN
,
// 1
};
enum
struct
GemmDataType
{
F16_F16_F16
,
// 0
BF16_INT8_BF16
,
// 1
BF16_BF16_BF16
// 2
};
#define OP_NAME "grouped_gemm_two_stage"
#define OP_DESC "Grouped GEMM TwoStage"
namespace
{
std
::
vector
<
int
>
argToIntArray
(
char
*
input
)
{
std
::
vector
<
int
>
out
;
std
::
istringstream
in
(
input
);
std
::
string
item
;
while
(
std
::
getline
(
in
,
item
,
','
))
{
out
.
push_back
(
std
::
stoi
(
item
));
}
return
out
;
}
int
profile_grouped_gemm_two_stage
(
int
argc
,
char
*
argv
[])
{
if
(
argc
<
14
)
{
std
::
cout
<<
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
<<
"arg2: data type (0: fp16; 1: bf16@int8; 2: bf16)
\n
"
<<
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n]);
\n
"
<<
"arg4: verification (0: no; 1: yes)
\n
"
<<
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
<<
"arg6: print tensor value (0: no; 1: yes)
\n
"
<<
"arg7: time kernel (0=n0, 1=yes)
\n
"
<<
"arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)
\n
"
<<
"arg15: kbatch value (default 1)
\n
"
<<
"optional:
\n
"
<<
"arg16: number of warm-up cycles (default 1)
\n
"
<<
"arg17: number of iterations (default 10)
\n
"
<<
std
::
endl
;
exit
(
1
);
}
const
auto
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
7
]);
const
auto
Ms
=
argToIntArray
(
argv
[
8
]);
const
auto
Ns
=
argToIntArray
(
argv
[
9
]);
const
auto
Ks
=
argToIntArray
(
argv
[
10
]);
auto
StrideAs
=
argToIntArray
(
argv
[
11
]);
auto
StrideBs
=
argToIntArray
(
argv
[
12
]);
auto
StrideCs
=
argToIntArray
(
argv
[
13
]);
const
int
kbatch
=
argc
==
15
?
std
::
stoi
(
argv
[
14
])
:
1
;
const
int
DefaultStrideA
=
Ks
[
0
];
const
int
DefaultStrideB
=
Ns
[
0
];
const
int
DefaultStrideC
=
Ns
[
0
];
for
(
size_t
i
=
0
;
i
<
Ms
.
size
();
++
i
)
{
StrideAs
[
i
]
=
StrideAs
[
i
]
==
-
1
?
DefaultStrideA
:
StrideAs
[
i
];
StrideBs
[
i
]
=
StrideBs
[
i
]
==
-
1
?
DefaultStrideB
:
StrideBs
[
i
];
StrideCs
[
i
]
=
StrideCs
[
i
]
==
-
1
?
DefaultStrideC
:
StrideCs
[
i
];
}
int
n_warmup
=
1
;
int
n_iter
=
10
;
if
(
argc
==
17
)
{
n_warmup
=
std
::
stoi
(
argv
[
16
]);
n_iter
=
std
::
stoi
(
argv
[
17
]);
}
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_grouped_gemm_two_stage_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
kbatch
,
n_warmup
,
n_iter
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_INT8_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_grouped_gemm_two_stage_impl
<
ck
::
bhalf_t
,
int8_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
kbatch
,
n_warmup
,
n_iter
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_INT8_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
ck
::
profiler
::
profile_grouped_gemm_two_stage_impl
<
ck
::
bhalf_t
,
int8_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
kbatch
,
n_warmup
,
n_iter
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_grouped_gemm_two_stage_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
kbatch
,
n_warmup
,
n_iter
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
ck
::
profiler
::
profile_grouped_gemm_two_stage_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
kbatch
,
n_warmup
,
n_iter
);
}
else
{
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
}
return
0
;
}
}
// anonymous namespace
REGISTER_PROFILER_OPERATION
(
OP_NAME
,
OP_DESC
,
profile_grouped_gemm_two_stage
);
profiler/src/profile_layernorm_fwd.cpp
View file @
9533a172
...
...
@@ -85,7 +85,7 @@ int profile_layernorm(int argc, char* argv[])
if
(
data_type
==
ck
::
DataTypeEnum
::
Half
)
{
ck
::
profiler
::
profile_layernorm_impl
<
F16
,
F16
,
F16
,
F32
,
F16
,
F
32
,
false
,
rank
>
(
ck
::
profiler
::
profile_layernorm_impl
<
F16
,
F16
,
F16
,
F32
,
F16
,
F
16
,
false
,
rank
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
length
);
}
else
if
(
data_type
==
ck
::
DataTypeEnum
::
Float
)
...
...
python/ck4inductor/batched_universal_gemm/gen_instances.py
0 → 100644
View file @
9533a172
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
import
logging
import
os
import
subprocess
from
dataclasses
import
replace
from
functools
import
lru_cache
from
typing
import
List
from
..util
import
library_path
from
.op
import
CKBatchedGemmOperation
log
=
logging
.
getLogger
(
__name__
)
def
_ck_library_dir
():
gemm_instances_path
=
os
.
path
.
join
(
library_path
(),
"src"
,
"tensor_operation_instance"
,
"gpu"
,
"gemm_universal_batched"
,
)
if
not
os
.
path
.
exists
(
gemm_instances_path
):
log
.
error
(
"CK library path %s does not exist"
,
gemm_instances_path
)
return
None
return
gemm_instances_path
def
parse_instances
(
str_instances
:
List
[
str
])
->
List
[
CKBatchedGemmOperation
]:
"""
Parse the lines containing Universal Gemm template instances into `CKBatchedGemmOperation` instances
"""
def
maybe_int
(
s
):
try
:
return
int
(
s
)
except
ValueError
:
return
s
op_instances
=
[]
for
line
in
str_instances
:
s_template_args
=
line
.
split
(
"DeviceBatchedGemmMultiD_Xdl_CShuffle_V3"
)[
-
1
].
strip
(
"<>, "
)
template_args
=
[]
i_current
=
0
while
i_current
<
len
(
s_template_args
):
if
s_template_args
[
i_current
]
==
" "
:
# skip whitespace
i_current
+=
1
continue
elif
s_template_args
[
i_current
:
i_current
+
2
]
==
"S<"
:
# parse template S<Index...>
i_next
=
s_template_args
.
find
(
">"
,
i_current
)
template_args
.
append
(
tuple
(
map
(
int
,
s_template_args
[
i_current
+
2
:
i_next
].
split
(
","
)))
)
i_current
=
i_next
+
2
else
:
# all string attributes must be either type aliases or global constants in C++
i_next
=
s_template_args
.
find
(
","
,
i_current
)
template_args
.
append
(
maybe_int
(
s_template_args
[
i_current
:
i_next
if
i_next
!=
-
1
else
None
]
)
)
if
i_next
!=
-
1
:
i_current
=
i_next
+
1
if
i_next
==
-
1
:
break
# ds layout and dtype are parsed as placeholder; reset value
template_args
[
2
]
=
tuple
()
# ds layout
template_args
[
6
]
=
tuple
()
# ds dtype
new_instance
=
CKBatchedGemmOperation
(
*
template_args
,
# type: ignore[arg-type]
)
op_instances
.
append
(
new_instance
)
return
op_instances
@
lru_cache
(
None
)
def
gen_ops_library
()
->
List
[
CKBatchedGemmOperation
]:
"""
Parse the Universal Gemm instances defined in the composable kernel library folder.
"""
ck_library_dir
=
_ck_library_dir
()
if
not
ck_library_dir
:
return
[]
grep_result
=
subprocess
.
run
(
[
"grep"
,
"-inR"
,
"DeviceBatchedGemmMultiD_Xdl_CShuffle_V3"
,
_ck_library_dir
(),
],
capture_output
=
True
,
text
=
True
,
)
op_instances
=
parse_instances
(
grep_result
.
stdout
.
strip
().
split
(
"
\n
"
))
log
.
debug
(
"ck instances from library: %d"
,
len
(
op_instances
))
schedulers
=
[
"BlockGemmPipelineScheduler::Intrawave"
,
"BlockGemmPipelineScheduler::Interwave"
,
]
gemm_specs
=
[
"GemmSpecialization::Default"
,
"GemmSpecialization::MPadding"
,
"GemmSpecialization::NPadding"
,
"GemmSpecialization::KPadding"
,
"GemmSpecialization::MNPadding"
,
"GemmSpecialization::MKPadding"
,
"GemmSpecialization::NKPadding"
,
"GemmSpecialization::MNKPadding"
,
]
# substitute templated args by looping through their domains
substitute_instances
=
[]
for
instance
in
op_instances
:
sub_scheduler
=
instance
.
block_gemm_pipeline_scheduler
==
"BlkGemmPipeSched"
sub_spec
=
instance
.
gemm_specialization
==
"GemmSpec"
schedulers_range
=
(
schedulers
if
sub_scheduler
else
[
instance
.
block_gemm_pipeline_scheduler
]
)
spec_range
=
gemm_specs
if
sub_spec
else
[
instance
.
gemm_specialization
]
for
scheduler
in
schedulers_range
:
for
spec
in
spec_range
:
substitute_instances
.
append
(
replace
(
instance
,
block_gemm_pipeline_scheduler
=
scheduler
,
gemm_specialization
=
spec
,
)
)
return
substitute_instances
if
__name__
==
"__main__"
:
print
(
gen_ops_library
())
python/ck4inductor/batched_universal_gemm/op.py
0 → 100644
View file @
9533a172
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
from
dataclasses
import
asdict
,
dataclass
from
typing
import
Optional
,
Tuple
@
dataclass
class
CKBatchedGemmOperation
:
"""
A python dataclass storing the template parameters of a CK Universal Gemm template instance
"""
a_layout
:
str
b_layout
:
str
ds_layouts
:
Tuple
[
str
]
# addmm specific
c_layout
:
str
a_element_dtype
:
str
b_element_dtype
:
str
ds_element_dtypes
:
Tuple
[
str
]
# addmm specific
c_element_dtype
:
str
acc_dtype
:
str
c_shuffle_dtype
:
str
a_elementwise_op
:
str
b_elementwise_op
:
str
c_elementwise_op
:
str
gemm_specialization
:
str
block_size
:
int
m_per_block
:
int
n_per_block
:
int
k_per_block
:
int
a_k1
:
int
b_k1
:
int
m_per_xdl
:
int
n_per_xdl
:
int
m_xdl_per_wave
:
int
n_xdl_per_wave
:
int
a_block_transfer_thread_cluster_lengths_ak0_m_ak1
:
Tuple
[
int
,
int
,
int
]
a_block_transfer_thread_cluster_arrange_order
:
Tuple
[
int
,
int
,
int
]
a_block_transfer_src_access_order
:
Tuple
[
int
,
int
,
int
]
a_block_transfer_src_vector_dim
:
int
a_block_transfer_src_scalar_per_vector
:
int
a_block_transfer_dst_scalar_per_vector_ak1
:
int
a_block_lds_extra_m
:
bool
b_block_transfer_thread_cluster_lengths_bk0_n_bk1
:
Tuple
[
int
,
int
,
int
]
b_block_transfer_thread_cluster_arrange_order
:
Tuple
[
int
,
int
,
int
]
b_block_transfer_src_access_order
:
Tuple
[
int
,
int
,
int
]
b_block_transfer_src_vector_dim
:
int
b_block_transfer_src_scalar_per_vector
:
int
b_block_transfer_dst_scalar_per_vector_bk1
:
int
b_block_lds_extra_n
:
bool
c_shuffle_m_xdl_per_wave_per_shuffle
:
int
c_shuffle_n_xdl_per_wave_per_shuffle
:
int
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
:
(
Tuple
[
int
,
int
,
int
,
int
]
)
c_shuffle_block_transfer_scalar_per_vector_n_per_block
:
Tuple
[
int
]
block_gemm_pipeline_scheduler
:
str
block_gemm_pipeline_version
:
str
a_compute_dtype
:
Optional
[
str
]
=
None
b_compute_dtype
:
Optional
[
str
]
=
None
def
name
(
self
):
# cpp alias for template instance
return
f
"ck_device_batched_gemm_multi_d_xdl_c_shuffle_v3_
{
self
.
key_name
()
}
"
def
key_name
(
self
):
# TBD; must be unique per instance. Intended to use as dict key
return
"_"
.
join
(
[
"K"
+
field_name
.
replace
(
"_"
,
""
).
lower
()
+
"V"
+
(
"x"
.
join
(
map
(
str
,
iter
(
field_value
)))
if
isinstance
(
field_value
,
tuple
)
else
str
(
field_value
).
replace
(
":"
,
""
)
)
for
field_name
,
field_value
in
self
.
dict_items
()
]
)
def
dict_items
(
self
):
return
asdict
(
self
).
items
()
python/ck4inductor/grouped_conv_fwd/gen_instances.py
View file @
9533a172
...
...
@@ -130,9 +130,7 @@ def gen_conv_ops_library() -> List[CKGroupedConvFwdOp]:
# substitute templated args by looping through their domains
substitute_instances
=
[]
for
instance
in
op_instances
:
sub_scheduler
=
(
instance
.
block_gemm_pipeline_scheduler
==
"BlkGemmPipeSched"
)
sub_scheduler
=
instance
.
block_gemm_pipeline_scheduler
==
"BlkGemmPipeSched"
sub_spec
=
instance
.
conv_forward_specialization
==
"ConvSpec"
schedulers_range
=
(
schedulers
if
sub_scheduler
else
[
instance
.
block_gemm_pipeline_scheduler
]
...
...
script/process_perf_data.py
View file @
9533a172
...
...
@@ -133,12 +133,12 @@ def parse_logfile(logfile):
if
'Best Perf'
in
line
:
lst
=
line
.
split
()
res
.
append
(
lst
[
4
])
elif
'onnx_gemm'
in
logfile
or
'mixed_gemm'
in
logfile
:
elif
'onnx_gemm'
in
logfile
:
for
line
in
open
(
logfile
):
if
'Best Perf'
in
line
:
lst
=
line
.
split
()
res
.
append
(
lst
[
33
])
elif
'splitK_gemm'
in
logfile
:
elif
'splitK_gemm'
in
logfile
or
'mixed_gemm'
in
logfile
:
for
line
in
open
(
logfile
):
if
'Best Perf'
in
line
:
lst
=
line
.
split
()
...
...
script/process_qa_data.sh
View file @
9533a172
...
...
@@ -22,6 +22,7 @@ python3 process_perf_data.py perf_gemm_bilinear.log
python3 process_perf_data.py perf_reduction.log
python3 process_perf_data.py perf_splitK_gemm.log
python3 process_perf_data.py perf_onnx_gemm.log
python3 process_perf_data.py perf_mixed_gemm.log
file
=
./perf_fmha_fwd_gfx942.log
if
[
-e
"
$file
"
]
;
then
...
...
test/CMakeLists.txt
View file @
9533a172
...
...
@@ -64,11 +64,11 @@ function(add_test_executable TEST_NAME)
#only continue if there are some source files left on the list
if
(
ARGN
)
if
(
ARGN MATCHES
"_xdl"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
list
(
REMOVE_ITEM TEST_TARGETS
gfx900 gfx906 gfx906:xnack-
gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
gfx10.3-generic gfx11-generic gfx12-generic
)
elseif
(
ARGN MATCHES
"_wmma"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
list
(
REMOVE_ITEM TEST_TARGETS
gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack-
gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
elseif
(
ARGN MATCHES
"_smfmac"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201
)
list
(
REMOVE_ITEM TEST_TARGETS
gfx900 gfx906 gfx906:xnack-
gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201
gfx10.3-generic gfx11-generic gfx12-generic
)
endif
()
set_source_files_properties
(
${
ARGN
}
PROPERTIES LANGUAGE HIP
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
...
...
@@ -141,11 +141,11 @@ function(add_gtest_executable TEST_NAME)
#only continue if there are some source files left on the list
if
(
ARGN
)
if
(
ARGN MATCHES
"_xdl"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906
gfx906:xnack-
gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
gfx10.3-generic gfx11-generic gfx12-generic
)
elseif
(
ARGN MATCHES
"_wmma"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906
gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack-
gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
elseif
(
ARGN MATCHES
"_smfmac"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201
)
list
(
REMOVE_ITEM TEST_TARGETS
gfx900 gfx906 gfx906:xnack-
gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201
gfx10.3-generic gfx11-generic gfx12-generic
)
endif
()
set_source_files_properties
(
${
ARGN
}
PROPERTIES LANGUAGE HIP
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
...
...
test/ck_tile/CMakeLists.txt
View file @
9533a172
add_subdirectory
(
image_to_column
)
add_subdirectory
(
gemm
)
add_subdirectory
(
batched_gemm
)
test/ck_tile/batched_gemm/CMakeLists.txt
0 → 100644
View file @
9533a172
# Currently ck_tile is only built on gfx9
if
(
GPU_TARGETS MATCHES
"gfx9"
)
add_gtest_executable
(
test_ck_tile_batched_gemm test_batched_gemm.cpp
)
endif
()
test/ck_tile/batched_gemm/test_batched_gemm.cpp
0 → 100644
View file @
9533a172
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_batched_gemm_util.hpp"
using
F16
=
ck_tile
::
half_t
;
using
F32
=
float
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
>
,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
>
//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
>
;
// clang-format on
TYPED_TEST_SUITE
(
TestCkTileBatchedGemm
,
KernelTypes
);
#include "test_batched_gemm_ut_cases.inc"
test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc
0 → 100644
View file @
9533a172
#pragma once
TYPED_TEST
(
TestCkTileBatchedGemm
,
Basic
)
{
constexpr
int
M
=
256
;
constexpr
int
N
=
128
;
constexpr
int
K
=
128
;
this
->
Run
(
M
,
N
,
K
);
}
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
0 → 100644
View file @
9533a172
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <sstream>
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
template
<
typename
Tuple
>
class
TestCkTileBatchedGemm
:
public
::
testing
::
Test
{
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
struct
batched_gemm_kargs
:
public
ck_tile
::
BatchedGemmHostArgs
{
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
void
invoke_batched_gemm
(
const
batched_gemm_kargs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kTilePermute
=
false
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
int
kBlockPerCu
=
1
;
// This part comes from the Codegen
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
kM
,
TilePartitioner
::
kN
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
s
.
log_level_
>
0
)
{
std
::
cout
<<
"Launching kernel with args:"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
}
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
}
public:
void
Run
(
const
int
M
,
const
int
N
,
const
int
K
,
int
StrideA
=
128
,
int
StrideB
=
128
,
int
StrideC
=
128
,
const
int
BatchStrideA
=
32768
,
const
int
BatchStrideB
=
16384
,
const
int
BatchStrideC
=
32768
,
const
int
BatchCount
=
16
)
{
using
namespace
ck_tile
::
literals
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
batch_count_
,
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
std
::
size_t
batch_stride
,
auto
layout
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
ck_tile
::
HostTensorDescriptor
({
batch_count_
,
row
,
col
},
{
batch_stride
,
stride
,
1
_uz
});
}
else
{
return
ck_tile
::
HostTensorDescriptor
({
batch_count_
,
row
,
col
},
{
batch_stride
,
1
_uz
,
stride
});
}
};
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
stride
==
0
)
{
// give a chance if stride is zero, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
col
;
}
else
{
return
row
;
}
}
else
return
stride
;
};
StrideA
=
f_get_default_stride
(
M
,
K
,
StrideA
,
ALayout
{});
StrideB
=
f_get_default_stride
(
K
,
N
,
StrideB
,
BLayout
{});
StrideC
=
f_get_default_stride
(
M
,
N
,
StrideC
,
CLayout
{});
ck_tile
::
HostTensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
K
,
StrideA
,
BatchStrideA
,
ALayout
{}));
ck_tile
::
HostTensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
BatchCount
,
K
,
N
,
StrideB
,
BatchStrideB
,
BLayout
{}));
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_dev_result
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
StrideC
,
BatchStrideC
,
CLayout
{}));
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
ck_tile
::
DeviceMem
a_m_k_dev_buf
(
a_m_k
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b_k_n_dev_buf
(
b_k_n
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
c_m_n_dev_buf
(
c_m_n_dev_result
.
get_element_space_size_in_bytes
());
a_m_k_dev_buf
.
ToDevice
(
a_m_k
.
data
());
b_k_n_dev_buf
.
ToDevice
(
b_k_n
.
data
());
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
batched_gemm_kargs
kargs
{
a_m_k_dev_buf
.
GetDeviceBuffer
(),
b_k_n_dev_buf
.
GetDeviceBuffer
(),
c_m_n_dev_buf
.
GetDeviceBuffer
(),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
BatchStrideA
,
BatchStrideB
,
BatchStrideC
,
BatchCount
};
invoke_batched_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
kargs
,
ck_tile
::
stream_config
{
nullptr
,
false
});
std
::
cout
<<
"Run kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
<<
" StrideA ="
<<
StrideA
<<
" StrideB ="
<<
StrideB
<<
" StrideC ="
<<
StrideC
<<
" BatchStrideA ="
<<
BatchStrideA
<<
" BatchStrideB ="
<<
BatchStrideB
<<
" BatchStrideC ="
<<
BatchStrideC
<<
" BatchCount ="
<<
BatchCount
<<
std
::
endl
;
c_m_n_dev_buf
.
FromDevice
(
c_m_n_dev_result
.
data
());
bool
pass
=
true
;
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_host_ref
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
StrideC
,
BatchStrideC
,
CLayout
{}));
c_m_n_host_ref
.
SetZero
();
const
auto
b_n_k
=
b_k_n
.
transpose
({
0
,
2
,
1
});
ck_tile
::
reference_batched_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
a_m_k
,
b_n_k
,
c_m_n_host_ref
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
);
EXPECT_TRUE
(
pass
);
}
};
test/ck_tile/gemm/test_gemm_mem_pipeline.cpp
View file @
9533a172
...
...
@@ -11,8 +11,20 @@
using
F16
=
ck_tile
::
half_t
;
using
F32
=
float
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
static
constexpr
auto
Intrawave
=
ck_tile
::
GemmPipelineScheduler
::
Intrawave
;
static
constexpr
auto
Interwave
=
ck_tile
::
GemmPipelineScheduler
::
Interwave
;
template
<
typename
Tuple
>
class
TestCkTileGemmMemPipelineIntrawave
:
public
TestCkTileGemmMemPipeline
<
Tuple
,
Intrawave
>
{
};
template
<
typename
Tuple
>
class
TestCkTileGemmMemPipelineInterwave
:
public
TestCkTileGemmMemPipeline
<
Tuple
,
Interwave
>
{
};
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
...
...
@@ -24,6 +36,7 @@ using KernelTypes = ::testing::Types<
>
;
// clang-format on
TYPED_TEST_SUITE
(
TestCkTileGemmMemPipeline
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestCkTileGemmMemPipelineIntrawave
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestCkTileGemmMemPipelineInterwave
,
KernelTypes
);
#include "test_gemm_mem_pipeline_ut_cases.inc"
test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc
View file @
9533a172
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
TYPED_TEST
(
TestCkTileGemmMemPipeline
,
SmallM
)
//------------------------------------------------------------------------------------------------
// INTERWAVE SCHEDULER
//------------------------------------------------------------------------------------------------
TYPED_TEST
(
TestCkTileGemmMemPipelineInterwave
,
SmallM
)
{
std
::
vector
<
int
>
Ms
{
1
,
2
,
3
,
4
,
5
,
6
};
constexpr
int
N
=
1024
;
constexpr
int
K
=
320
;
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemmMemPipelineInterwave
,
MidLargeM
)
{
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
constexpr
int
N
=
1024
;
constexpr
int
K
=
320
;
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemmMemPipelineInterwave
,
PaddK
)
{
std
::
vector
<
int
>
Ms
{
127
};
constexpr
int
N
=
1024
;
constexpr
int
K
=
432
;
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemmMemPipelineInterwave
,
Regular
)
{
std
::
vector
<
int
>
Ms
{
512
};
constexpr
int
N
=
1024
;
constexpr
int
K
=
512
;
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
);
}
//------------------------------------------------------------------------------------------------
// INTRAWAVE SCHEDULER
//------------------------------------------------------------------------------------------------
TYPED_TEST
(
TestCkTileGemmMemPipelineIntrawave
,
SmallM
)
{
std
::
vector
<
int
>
Ms
{
1
,
2
,
3
,
4
,
5
,
6
};
constexpr
int
N
=
1024
;
...
...
@@ -10,7 +61,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemmMemPipeline
,
MidLargeM
)
TYPED_TEST
(
TestCkTileGemmMemPipeline
Intrawave
,
MidLargeM
)
{
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
constexpr
int
N
=
1024
;
...
...
@@ -20,7 +71,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemmMemPipeline
,
PaddK
)
TYPED_TEST
(
TestCkTileGemmMemPipeline
Intrawave
,
PaddK
)
{
std
::
vector
<
int
>
Ms
{
127
};
constexpr
int
N
=
1024
;
...
...
@@ -30,7 +81,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemmMemPipeline
,
Regular
)
TYPED_TEST
(
TestCkTileGemmMemPipeline
Intrawave
,
Regular
)
{
std
::
vector
<
int
>
Ms
{
512
};
constexpr
int
N
=
1024
;
...
...
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
View file @
9533a172
...
...
@@ -11,20 +11,21 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
template
<
typename
Tuple
>
template
<
typename
Tuple
,
ck_tile
::
GemmPipelineScheduler
Scheduler_
>
class
TestCkTileGemmMemPipeline
:
public
::
testing
::
Test
{
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
// TODO: expose tile size through test t-param ?
struct
gemm_
basic_
args
struct
gemm_args
{
const
void
*
p_a
;
const
void
*
p_b
;
...
...
@@ -38,7 +39,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
ck_tile
::
index_t
stride_C
;
};
void
invoke_gemm
(
const
gemm_
basic_
args
&
args
,
const
ck_tile
::
stream_config
&
s
)
void
invoke_gemm
(
const
gemm_args
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// TODO: This should be parameterized in tests
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
...
...
@@ -53,9 +54,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
bool
kPad
A
=
true
;
constexpr
bool
kPad
B
=
true
;
constexpr
bool
kPad
C
=
true
;
constexpr
bool
kPad
M
=
true
;
constexpr
bool
kPad
N
=
true
;
constexpr
bool
kPad
K
=
true
;
constexpr
int
kBlockPerCu
=
1
;
...
...
@@ -68,9 +69,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
false
,
kPad
C
>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPad
N
>>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPad
A
,
kPad
B
,
kPad
C
,
ALayout
,
BLayout
,
CLayout
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPad
M
,
kPad
N
,
kPad
K
,
ALayout
,
BLayout
,
CLayout
>
;
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
;
...
...
@@ -89,7 +90,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
AccDataType
,
GemmShape
,
Traits
,
ck_tile
::
GemmPipelineScheduler
::
Intrawave
,
Scheduler
,
has_hot_loop_v
,
tail_number_v
>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
...
...
@@ -108,7 +109,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
if
(
s
.
log_level_
>
0
)
{
std
::
cout
<<
"Lunching kernel with args:"
std
::
cout
<<
"L
a
unching kernel with args:"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
...
...
@@ -288,7 +289,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
gemm_
basic_
args
args
;
gemm_args
args
;
args
.
p_a
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
p_b
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
p_c
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
...
...
test/gemm_universal/test_gemm_universal_xdl.cpp
View file @
9533a172
...
...
@@ -56,7 +56,7 @@ class TestGemmUniversal_KM_NK
using
KernelTypes_MK_KN
=
::
testing
::
Types
<
// ADataType, BDataType, ComputeDataType, CDataType
std
::
tuple
<
F16
,
F16
,
F16
,
F16
>
,
#if
(
defined
CK_ENABLE_FP8)
#if defined
(
CK_ENABLE_FP8)
&& (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std
::
tuple
<
F16
,
F8
,
F16
,
F16
>
,
std
::
tuple
<
F8
,
F16
,
F16
,
F16
>
,
std
::
tuple
<
F8
,
F8
,
F8
,
BF16
>
,
...
...
@@ -66,7 +66,7 @@ using KernelTypes_MK_KN = ::testing::Types<
using
KernelTypes_MK_NK
=
::
testing
::
Types
<
// ADataType, BDataType, ComputeDataType, CDataType
std
::
tuple
<
F16
,
F16
,
F16
,
F16
>
,
#if
(
defined
CK_ENABLE_FP8)
#if defined
(
CK_ENABLE_FP8)
&& (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std
::
tuple
<
F16
,
F8
,
F16
,
F16
>
,
std
::
tuple
<
F8
,
F16
,
F16
,
F16
>
,
std
::
tuple
<
F8
,
F8
,
F8
,
BF16
>
,
...
...
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
View file @
9533a172
...
...
@@ -58,13 +58,13 @@ using KernelTypes1d = ::testing::Types<std::tuple<float, GNWC, GKXC, GNWK>,
using
KernelTypes2d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
GNHWC
,
GKYXC
,
GNHWK
>
,
std
::
tuple
<
ck
::
half_t
,
GNHWC
,
GKYXC
,
GNHWK
>
,
std
::
tuple
<
ck
::
bhalf_t
,
GNHWC
,
GKYXC
,
GNHWK
>
,
std
::
tuple
<
int8_t
,
GNHWC
,
GKYXC
,
GNHWK
>
,
std
::
tuple
<
float
,
NHWGC
,
GKYXC
,
NHWGK
>
,
std
::
tuple
<
ck
::
half_t
,
NHWGC
,
GKYXC
,
NHWGK
>
,
std
::
tuple
<
ck
::
bhalf_t
,
NHWGC
,
GKYXC
,
NHWGK
>
,
std
::
tuple
<
int8_t
,
NHWGC
,
GKYXC
,
NHWGK
>
,
std
::
tuple
<
float
,
NGCHW
,
GKYXC
,
NGKHW
>
,
std
::
tuple
<
ck
::
half_t
,
NGCHW
,
GKYXC
,
NGKHW
>>
;
std
::
tuple
<
ck
::
half_t
,
NGCHW
,
GKYXC
,
NGKHW
>
,
std
::
tuple
<
int8_t
,
NGCHW
,
GKYXC
,
NGKHW
>>
;
using
KernelTypes3d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
GNDHWC
,
GKZYXC
,
GNDHWK
>
,
std
::
tuple
<
ck
::
half_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
>
,
...
...
test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases_xdl.cpp
View file @
9533a172
...
...
@@ -52,7 +52,8 @@ using namespace ck::tensor_layout::convolution;
using
KernelTypes2d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
NHWGC
,
GKYXC
,
NHWGK
>
,
std
::
tuple
<
ck
::
half_t
,
NHWGC
,
GKYXC
,
NHWGK
>
,
std
::
tuple
<
ck
::
bhalf_t
,
NHWGC
,
GKYXC
,
NHWGK
>>
;
std
::
tuple
<
ck
::
bhalf_t
,
NHWGC
,
GKYXC
,
NHWGK
>
,
std
::
tuple
<
int8_t
,
NHWGC
,
GKYXC
,
NHWGK
>>
;
using
KernelTypes3d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
NDHWGC
,
GKZYXC
,
NDHWGK
>
,
std
::
tuple
<
ck
::
half_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
>
,
...
...
test/grouped_gemm/CMakeLists.txt
View file @
9533a172
...
...
@@ -6,12 +6,6 @@ if(result EQUAL 0)
add_dependencies
(
test_grouped_gemm test_grouped_gemm_splitk
)
endif
()
add_gtest_executable
(
test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance
)
add_dependencies
(
test_grouped_gemm test_grouped_gemm_two_stage_splitk
)
endif
()
add_gtest_executable
(
test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance
)
...
...
Prev
1
…
21
22
23
24
25
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment