gaoqiong / composable_kernel_ROCM / Commits

Commit 1b616990, authored Feb 05, 2025 by aska-0096

Merge branch 'develop' of https://github.com/ROCm/composable_kernel into update_cka8w8_uc

Parents: af30d6b6, 800cf897
Changes: 574 files in total; this page shows 14 changed files with 303 additions and 58 deletions (+303 / -58).
python/ck4inductor/universal_gemm/gen_instances.py    +7   -6
python/test/test_gen_instances.py                     +46  -0
script/cmake-ck-dev.sh                                +1   -1
script/process_perf_data.py                           +14  -0
script/process_perf_data.sh                           +16  -0
script/process_qa_data.sh                             +16  -0
test/CMakeLists.txt                                   +70  -3
test/ck_tile/CMakeLists.txt                           +1   -0
test/ck_tile/batched_gemm/test_batched_gemm.cpp       +1   -1
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp  +21  -29
test/ck_tile/data_type/CMakeLists.txt                 +4   -0
test/ck_tile/data_type/test_pk_int4.cpp               +65  -0
test/ck_tile/gemm/test_gemm_pipeline.cpp              +15  -13
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc     +26  -5
python/ck4inductor/universal_gemm/gen_instances.py

@@ -68,12 +68,13 @@ def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
        template_args.insert(2, tuple())  # ds layout
        template_args.insert(6, tuple())  # ds dtype
        try:
            new_instance = CKGemmOperation(
                *template_args,  # type: ignore[arg-type]
            )
            op_instances.append(new_instance)
        except TypeError as e:
            log.debug(f"{e} when parsing {line}")
    return op_instances
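The two insert calls above pad the parsed template arguments so they line up positionally with a CKGemmOperation signature that carries Ds-layout and Ds-dtype slots. A minimal sketch of the mechanics, using a hypothetical shortened argument list (the real field order lives in CKGemmOperation):

# Hypothetical, shortened argument list; this only illustrates how insert()
# shifts the remaining positional arguments to the right.
template_args = ["Row", "Row", "f16", "f16", "f16"]
template_args.insert(2, tuple())  # ds layout slot -> ["Row", "Row", (), "f16", "f16", "f16"]
template_args.insert(6, tuple())  # ds dtype slot  -> ["Row", "Row", (), "f16", "f16", "f16", ()]
print(template_args)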
python/test/test_gen_instances.py (new file, mode 100644)

# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

import logging
import unittest

from ck4inductor.universal_gemm.gen_instances import (
    gen_ops_library as gen_gemm_ops_library,
)
from ck4inductor.universal_gemm.gen_instances import (
    gen_ops_preselected as gen_gemm_ops_preselected,
)
from ck4inductor.grouped_conv_fwd.gen_instances import (
    gen_conv_ops_library as gen_conv_ops_library,
)
from ck4inductor.batched_universal_gemm.gen_instances import (
    gen_ops_library as gen_batched_gemm_ops_library,
)

log = logging.getLogger(__name__)


class TestGenInstances(unittest.TestCase):
    def test_gen_gemm_instances(self):
        instances = gen_gemm_ops_library()
        log.debug("%d gemm instances from library" % len(instances))
        self.assertTrue(instances)

    def test_preselected_gemm_instances(self):
        instances = gen_gemm_ops_preselected()
        log.debug("%d preselected gemm instances" % len(instances))
        self.assertTrue(instances)

    def test_gen_conv_instances(self):
        instances = gen_conv_ops_library()
        log.debug("%d gemm instances from library" % len(instances))
        self.assertTrue(instances)

    def test_gen_batched_gemm_instances(self):
        instances = gen_batched_gemm_ops_library()
        log.debug("%d gemm instances from library" % len(instances))
        self.assertTrue(instances)
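The new tests only assert that each generator returns a non-empty list. A minimal usage sketch (assuming the repository's python/ directory is on PYTHONPATH, as the test module itself assumes) that prints the counts instead:

# Uses the same entry points the test file imports.
from ck4inductor.universal_gemm.gen_instances import gen_ops_library, gen_ops_preselected

library_instances = gen_ops_library()
preselected_instances = gen_ops_preselected()
print(f"{len(library_instances)} library instances, {len(preselected_instances)} preselected")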
script/cmake-ck-dev.sh

@@ -15,7 +15,7 @@ else
 fi
 cmake                                                                                        \
--D CMAKE_PREFIX_PATH=/opt/rocm                                                               \
+-D CMAKE_PREFIX_PATH=/opt/rocm/                                                              \
 -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc                                                    \
 -D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker"  \
 -D CMAKE_BUILD_TYPE=Release                                                                  \
script/process_perf_data.py

@@ -149,6 +149,12 @@ def parse_logfile(logfile):
             lst = line.split()
             line_dict = dict(zip(lst[1:], lst))
             res.append(line_dict['TFlops,'])
+    elif 'perf_tile_gemm_basic' in logfile or 'perf_tile_gemm_mem_pipeline' in logfile:
+        for line in open(logfile):
+            if 'TFlops' in line:
+                lst = line.split()
+                line_dict = dict(zip(lst[1:], lst))
+                res.append(line_dict['TFlops,'])
     return res

@@ -330,6 +336,14 @@ def main():
         for i in range(1, len(results) + 1):
             testlist.append("Test%i" % i)
         table_name = "ck_fmha_bwd_tflops"
+    if 'gemm_basic_fp16' in filename:
+        for i in range(1, len(results) + 1):
+            testlist.append("Test%i" % i)
+        table_name = "ck_tile_gemm_basic_fp16_tflops"
+    if 'gemm_mem_pipeline_fp16' in filename:
+        for i in range(1, len(results) + 1):
+            testlist.append("Test%i" % i)
+        table_name = "ck_tile_gemm_mem_pipeline_fp16_tflops"
     tflops_base = get_baseline(table_name, conn)
     store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch,
                           compute_units, rocm_vers, hip_vers, environment, sqlEngine)
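The parse_logfile idiom above pairs every whitespace-separated token with the token to its left, so looking up the key 'TFlops,' returns the number printed just before that word. A worked example with a hypothetical log line (the real benchmark output may differ):

# Hypothetical log line; only the "<number> TFlops," pattern matters here.
line = "[gemm] 123.4 TFlops, 456.7 GB/s"
lst = line.split()                   # ['[gemm]', '123.4', 'TFlops,', '456.7', 'GB/s']
line_dict = dict(zip(lst[1:], lst))  # maps each token to the token before it
print(line_dict['TFlops,'])          # '123.4'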
script/process_perf_data.sh

@@ -43,3 +43,19 @@ file=./perf_fmha_bwd_gfx90a.log
 if [ -e "$file" ]; then
     python3 process_perf_data.py perf_fmha_bwd_gfx90a.log
 fi
+file=./perf_tile_gemm_basic_fp16_gfx942.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_tile_gemm_basic_fp16_gfx942.log
+fi
+file=./perf_tile_gemm_basic_fp16_gfx90a.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_tile_gemm_basic_fp16_gfx90a.log
+fi
+file=./perf_tile_gemm_mem_pipeline_fp16_gfx942.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_tile_gemm_mem_pipeline_fp16_gfx942.log
+fi
+file=./perf_tile_gemm_mem_pipeline_fp16_gfx90a.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_tile_gemm_mem_pipeline_fp16_gfx90a.log
+fi
script/process_qa_data.sh

@@ -52,3 +52,19 @@ file=./perf_fmha_bwd_gfx90a.log
 if [ -e "$file" ]; then
     python3 process_perf_data.py perf_fmha_bwd_gfx90a.log
 fi
+file=./perf_gemm_basic_gfx942.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_gemm_basic_gfx942.log
+fi
+file=./perf_gemm_basic_gfx90a.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_gemm_basic_gfx90a.log
+fi
+file=./perf_gemm_mem_pipeline_gfx942.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_gemm_mem_pipeline_gfx942.log
+fi
+file=./perf_gemm_mem_pipeline_gfx90a.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_gemm_mem_pipeline_gfx90a.log
+fi
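Both scripts repeat the same guard for every new log: run process_perf_data.py only when the file exists. An equivalent sketch of that pattern in Python (illustrative only; the CI scripts themselves stay in shell):

import os
import subprocess

# Log names taken from the process_qa_data.sh hunk above.
logs = [
    "perf_gemm_basic_gfx942.log",
    "perf_gemm_basic_gfx90a.log",
    "perf_gemm_mem_pipeline_gfx942.log",
    "perf_gemm_mem_pipeline_gfx90a.log",
]
for log in logs:
    if os.path.exists(log):  # mirrors: if [ -e "$file" ]; then ...
        subprocess.run(["python3", "process_perf_data.py", log], check=True)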
test/CMakeLists.txt

@@ -7,6 +7,34 @@ include(gtest)
 add_custom_target(tests)
+# list of tests that are labelled as REGRESSION_TEST for make regression (runtime more than 30 seconds)
+# all other tests are labelled as SMOKE_TEST
+set(REGRESSION_TESTS
+    test_gemm_standalone_xdl_fp16
+    test_gemm_fp16
+    test_gemm_splitk
+    test_batched_gemm
+    test_gemm_universal
+    test_batched_gemm_softmax_gemm_fp16
+    test_batched_gemm_softmax_gemm_permute_fp16
+    test_batched_gemm_bias_softmax_gemm_permute_fp16
+    test_batched_gemm_softmax_gemm_permute_bf16
+    test_batched_gemm_bias_softmax_gemm_permute_bf16
+    test_grouped_gemm_splitk
+    test_reduce_no_index
+    test_reduce_with_index
+    test_convnd_fwd
+    test_convnd_bwd_data
+    test_grouped_convnd_fwd
+    test_grouped_convnd_bwd_weight
+    test_softmax_rank3
+    test_softmax_rank4
+    test_batchnorm_fwd_rank_4
+    test_batchnorm_bwd_rank_4
+    test_grouped_convnd_bwd_data_xdl
+    test_conv_tensor_rearrange
+)
 function(add_test_executable TEST_NAME)
     message("adding test ${TEST_NAME}")
     set(result 1)

@@ -43,6 +71,12 @@ function(add_test_executable TEST_NAME)
     set(TEST_TARGETS ${SUPPORTED_GPU_TARGETS})
+    foreach(source IN LISTS ARGN)
+        if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp")
+            message("removing dpp test ${source}")
+            list(REMOVE_ITEM ARGN "${source}")
+        endif()
+    endforeach()
     foreach(source IN LISTS ARGN)
         if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
             message("removing dl test ${source}")

@@ -66,7 +100,7 @@ function(add_test_executable TEST_NAME)
     if(ARGN MATCHES "_xdl")
         list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
     elseif(ARGN MATCHES "_wmma")
-        list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+        list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
     elseif(ARGN MATCHES "_smfmac")
         list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
     endif()

@@ -82,6 +116,15 @@ function(add_test_executable TEST_NAME)
     endif()
     #message("add_test returns ${result}")
     set(result ${result} PARENT_SCOPE)
+    if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
+        message("adding to SMOKE TEST FILTER ${TEST_NAME}")
+        set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST")
+        add_dependencies(smoke ${TEST_NAME})
+    elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
+        message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
+        set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST")
+        add_dependencies(regression ${TEST_NAME})
+    endif()
 endfunction()

 function(add_gtest_executable TEST_NAME)

@@ -126,26 +169,38 @@ function(add_gtest_executable TEST_NAME)
             list(REMOVE_ITEM ARGN "${source}")
         endif()
     endforeach()
     foreach(source IN LISTS ARGN)
         if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
             message("removing xdl test ${source}")
             list(REMOVE_ITEM ARGN "${source}")
         endif()
     endforeach()
+    foreach(source IN LISTS ARGN)
+        if(NOT TEST_TARGETS MATCHES "gfx95" AND source MATCHES "mx_")
+            message("removing microscaling test ${source}")
+            list(REMOVE_ITEM ARGN "${source}")
+        endif()
+    endforeach()
     foreach(source IN LISTS ARGN)
         if(NOT TEST_TARGETS MATCHES "gfx11" AND NOT TEST_TARGETS MATCHES "gfx12" AND source MATCHES "wmma")
             message("removing wmma test ${source}")
             list(REMOVE_ITEM ARGN "${source}")
         endif()
     endforeach()
     #only continue if there are some source files left on the list
     if(ARGN)
         if(ARGN MATCHES "_xdl")
             list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
         elseif(ARGN MATCHES "_wmma")
-            list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+            list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
         elseif(ARGN MATCHES "_smfmac")
             list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
+        elseif(ARGN MATCHES "_mx")
+            #only build mx example for gfx950
+            list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
         endif()
         set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
         add_executable(${TEST_NAME} ${ARGN})

@@ -162,6 +217,15 @@ function(add_gtest_executable TEST_NAME)
     endif()
     #message("add_gtest returns ${result}")
     set(result ${result} PARENT_SCOPE)
+    if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
+        #message("adding to smoke test FILTER ${TEST_NAME}")
+        set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST")
+        add_dependencies(smoke ${TEST_NAME})
+    elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
+        #message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
+        set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST")
+        add_dependencies(regression ${TEST_NAME})
+    endif()
 endfunction()

 add_compile_options(-Wno-c++20-extensions)

@@ -206,8 +270,11 @@ add_subdirectory(wrapper)
 if(SUPPORTED_GPU_TARGETS MATCHES "gfx11")
     add_subdirectory(wmma_op)
 endif()
-if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2
+if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") # smfmac needs ROCm6.2
     add_subdirectory(smfmac_op)
 endif()
+if(SUPPORTED_GPU_TARGETS MATCHES "gfx950")
+    add_subdirectory(mx_mfma_op)
+endif()
 add_subdirectory(position_embedding)
 add_subdirectory(scatter_gather)
test/ck_tile/CMakeLists.txt

@@ -2,3 +2,4 @@ add_subdirectory(image_to_column)
 add_subdirectory(gemm)
 add_subdirectory(batched_gemm)
 add_subdirectory(grouped_gemm)
+add_subdirectory(data_type)
test/ck_tile/batched_gemm/test_batched_gemm.cpp

@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
 // clang-format off
 using KernelTypes = ::testing::Types<
     // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
     std::tuple< Row, Row, Row, F16, F16, F32, F16>,
     // std::tuple< Row, Row, Row, F16, F16, F32, F16>,
     //std::tuple< Col, Row, Row, F16, F16, F32, F16>,
     std::tuple< Row, Col, Row, F16, F16, F32, F16>
     //,
     //std::tuple< Col, Col, Row, F16, F16, F32, F16>
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include <sstream>

@@ -32,9 +32,6 @@ class TestCkTileBatchedGemm : public ::testing::Test
         constexpr bool kPadM = false;
         constexpr bool kPadN = false;
         constexpr bool kPadK = false;
-        constexpr bool kTilePermute = false;
-        // The rank and permutation will also be generate out by the CodeGen part.
-        constexpr ck_tile::index_t kOutputRank = 2;
         constexpr int kBlockPerCu = 1;

@@ -51,32 +48,12 @@ class TestCkTileBatchedGemm : public ::testing::Test
         constexpr ck_tile::index_t N_Warp_Tile = 32;
         constexpr ck_tile::index_t K_Warp_Tile = 8;

-        // Whether doing the CShuffle (transpose before the global memory), depending on the output
-        // layout.
-        constexpr bool CShuffleEpilogue =
-            std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;

         using CodegenGemmShape =
             ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                                    ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                    ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;

-        using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
-
-        using GemmEpilogue = std::conditional_t<
-            CShuffleEpilogue,
-            ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
-                                                                       CDataType,
-                                                                       kPadM,
-                                                                       kPadN,
-                                                                       kTilePermute,
-                                                                       kOutputRank,
-                                                                       1,
-                                                                       0,
-                                                                       TilePartitioner::kM,
-                                                                       TilePartitioner::kN>>,
-            ck_tile::Default2DEpilogue<
-                ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
+        using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;

         using CodegenGemmTraits =
             ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;

@@ -88,12 +65,26 @@ class TestCkTileBatchedGemm : public ::testing::Test
                                             CodegenGemmTraits>;
         using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
+        using GemmEpilogue = ck_tile::CShuffleEpilogue<
+            ck_tile::CShuffleEpilogueProblem<AccDataType,
+                                             CDataType,
+                                             CLayout,
+                                             CodegenGemmPipeline::BlockSize,
+                                             TilePartitioner::MPerBlock,
+                                             TilePartitioner::NPerBlock,
+                                             M_Warp,
+                                             N_Warp,
+                                             M_Warp_Tile,
+                                             N_Warp_Tile,
+                                             K_Warp_Tile,
+                                             CodegenPipelineProblem::TransposeC>>;
         using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;

         auto kargs = Kernel::MakeKernelArgs(args);
-        const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
+        const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
         constexpr dim3 blocks = Kernel::BlockSize();

         if(s.log_level_ > 0)

@@ -186,6 +177,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
         args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
         args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
         args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
+        args.k_batch = 1;
         args.M = M;
         args.N = N;
         args.K = K;
test/ck_tile/data_type/CMakeLists.txt (new file, mode 100644)

# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
    add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp)
endif()
test/ck_tile/data_type/test_pk_int4.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include <hip/hip_runtime.h>

#include "ck_tile/core.hpp"

using ck_tile::bf16_t;
using ck_tile::bf16x2_t;
using ck_tile::fp16x2_t;
using ck_tile::fp32x2_t;
using ck_tile::half_t;
using ck_tile::pk_int4_t;

TEST(PackedInt4, ConvertToFloat)
{
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    constexpr float first_input_val  = 7.f;
    constexpr float second_input_val = -1.f;
#else
    constexpr float first_input_val  = -1.f;
    constexpr float second_input_val = 7.f;
#endif
    uint8_t data = 0b11110111; // {-1, 7}
    pk_int4_t in = ck_tile::bit_cast<int8_t>(data);
    fp32x2_t out = ck_tile::pk_int4_t_to_fp32x2_t(in);

    EXPECT_EQ(out.x, first_input_val);
    EXPECT_EQ(out.y, second_input_val);
}

TEST(PackedInt4, ConvertToHalf)
{
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    const half_t first_input_val  = ck_tile::type_convert<half_t>(7.f);
    const half_t second_input_val = ck_tile::type_convert<half_t>(-1.f);
#else
    const half_t first_input_val  = ck_tile::type_convert<half_t>(-1.f);
    const half_t second_input_val = ck_tile::type_convert<half_t>(7.f);
#endif
    uint8_t data = 0b11110111; // {-1, 7}
    pk_int4_t in = ck_tile::bit_cast<int8_t>(data);
    fp16x2_t out = ck_tile::pk_int4_t_to_halfx2_t(in);

    EXPECT_EQ(out.x, first_input_val);
    EXPECT_EQ(out.y, second_input_val);
}

TEST(PackedInt4, ConvertToBHalf)
{
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    const bf16_t first_input_val  = ck_tile::type_convert<bf16_t>(7.f);
    const bf16_t second_input_val = ck_tile::type_convert<bf16_t>(-1.f);
#else
    const bf16_t first_input_val  = ck_tile::type_convert<bf16_t>(-1.f);
    const bf16_t second_input_val = ck_tile::type_convert<bf16_t>(7.f);
#endif
    uint8_t data = 0b11110111; // {-1, 7}
    pk_int4_t in = ck_tile::bit_cast<int8_t>(data);
    bf16x2_t out = ck_tile::pk_int4_t_to_bfloat16x2_t(in);

    EXPECT_EQ(out.x, first_input_val);
    EXPECT_EQ(out.y, second_input_val);
}
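Each pk_int4_t byte packs two signed 4-bit values; the tests above decode 0b11110111 into the pair {-1, 7}, and CK_TILE_USE_PK4_LAYOUT_SHUFFLE decides which of the two comes out first. A minimal sketch of the two's-complement nibble decoding (illustrative only; the actual conversion lives in the ck_tile::pk_int4_t_to_* helpers):

def unpack_pk_int4(byte):
    """Split one byte into (low nibble, high nibble) as signed 4-bit values."""
    def to_signed(nibble):
        return nibble - 16 if nibble >= 8 else nibble  # two's complement on 4 bits
    return to_signed(byte & 0xF), to_signed((byte >> 4) & 0xF)

# 0b11110111: low nibble 0x7 -> 7, high nibble 0xF -> -1, matching the {-1, 7} comment.
print(unpack_pk_int4(0b11110111))  # (7, -1)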
test/ck_tile/gemm/test_gemm_pipeline.cpp

@@ -14,26 +14,28 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
 using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

 using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
                                              ck_tile::GemmPipelineScheduler::Intrawave>;
 using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
                                              ck_tile::GemmPipelineScheduler::Interwave>;
 using Mem  = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
 // using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
 //                                              ck_tile::GemmPipelineScheduler::Interwave>;
 // using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
 using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>;

 // TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors.
 // clang-format off
 using KernelTypes = ::testing::Types<
     // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
     std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
     // std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
     std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
     std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
     std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
     // std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
     // std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
     std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
     std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
     std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
     // std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
     // std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
     std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
     std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
     std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
     std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
     std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
     // std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
     // std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
     std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>
     // std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
     >;
 // clang-format on
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc

@@ -10,7 +10,13 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM)
     constexpr int K = 320;

     for(int M : Ms)
     {
         if constexpr(std::is_same_v<typename TestFixture::ALayout,
                                     ck_tile::tensor_layout::gemm::ColumnMajor>)
             EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
         else
             this->Run(M, N, K);
     }
 }

 TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)

@@ -18,14 +24,29 @@ TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
     std::vector<int> Ms{127, 255, 312, 799, 1573};
     constexpr int N = 1024;
     constexpr int K = 320;
+    constexpr int VecLoadSize = 8;

     for(int M : Ms)
     {
+        if constexpr(std::is_same_v<typename TestFixture::ALayout,
+                                    ck_tile::tensor_layout::gemm::ColumnMajor>)
+        {
+            // TODO: Can we anyhow deduce used vector load size?
+            if(M % VecLoadSize == 0)
+                this->Run(M, N, K);
+            else
+                EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
+        }
+        else
+        {
+            this->Run(M, N, K);
+        }
     }
 }

 TYPED_TEST(TestCkTileGemmPipeline, PaddK)
 {
-    std::vector<int> Ms{127};
+    std::vector<int> Ms{128};
     constexpr int N = 1024;
     constexpr int K = 432;
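For the MidLargeM case the column-major A path now only expects a successful run when M is a multiple of the assumed vector load size. A quick worked check of the values above:

# With VecLoadSize = 8, only 312 of the tested M values divides evenly,
# so the other sizes are expected to throw for column-major A.
Ms = [127, 255, 312, 799, 1573]
print([M for M in Ms if M % 8 == 0])  # [312]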