Commit 401e643e authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Merge branch 'develop' into feature/use-larger-tile-size-for-chunk-prefill

parents d783a8cf fdfe2102
This diff is collapsed.
...@@ -7,6 +7,7 @@ import copy ...@@ -7,6 +7,7 @@ import copy
NS = 'ck_tile' NS = 'ck_tile'
OPS = 'ops' OPS = 'ops'
REF = 'ref'
OPS_COMMON = 'common' # common header will be duplicated into ops/* other module OPS_COMMON = 'common' # common header will be duplicated into ops/* other module
HEADER_COMMON = f"""// SPDX-License-Identifier: MIT HEADER_COMMON = f"""// SPDX-License-Identifier: MIT
...@@ -29,6 +30,9 @@ class submodule_t: ...@@ -29,6 +30,9 @@ class submodule_t:
def push(self, f): def push(self, f):
if len(f.parents) != 1: # ignore ./xxx.hpp if len(f.parents) != 1: # ignore ./xxx.hpp
mod = get_module(f) mod = get_module(f)
# ref is supposed to include one header on demand
if mod == REF:
return
if mod == OPS: if mod == OPS:
if mod not in self.m.keys(): if mod not in self.m.keys():
self.m[mod] = dict() self.m[mod] = dict()
......
...@@ -52,6 +52,9 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = ...@@ -52,6 +52,9 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances =
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 160, 64, 8, 8, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 32, 32, 1, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 64, 8, 8, 32, 32, 5, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
......
...@@ -42,6 +42,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std ...@@ -42,6 +42,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std
//##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#ifdef __gfx94__ #ifdef __gfx94__
// Compute friendly // Compute friendly
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
...@@ -72,6 +73,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std: ...@@ -72,6 +73,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std:
//##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
......
...@@ -6,7 +6,7 @@ set(CK_TILE_SRC_FOLDER ${CMAKE_SOURCE_DIR}/include/ck_tile/) ...@@ -6,7 +6,7 @@ set(CK_TILE_SRC_FOLDER ${CMAKE_SOURCE_DIR}/include/ck_tile/)
# CK Codegen requires dataclass which is added in Python 3.7 # CK Codegen requires dataclass which is added in Python 3.7
# Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04 # Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04
if(NOT CK_USE_ALTERNATIVE_PYTHON) if(NOT CK_USE_ALTERNATIVE_PYTHON)
find_package(PythonInterp 3 REQUIRED) find_package(Python3 COMPONENTS Interpreter Development)
else() else()
message("Using alternative python version") message("Using alternative python version")
set(EXTRA_PYTHON_PATH) set(EXTRA_PYTHON_PATH)
...@@ -33,7 +33,7 @@ set(FMHA_KNOWN_APIS "fwd,fwd_splitkv,fwd_appendkv,bwd") ...@@ -33,7 +33,7 @@ set(FMHA_KNOWN_APIS "fwd,fwd_splitkv,fwd_appendkv,bwd")
# Note: The receipt 3 arg filters the generated backwards instances to reduce compilation time. # Note: The receipt 3 arg filters the generated backwards instances to reduce compilation time.
# With receipt 3 set, we are generating instances for datatype == {fp16 || bfp16}, bias == {no || alibi}, deterministic == off, and dpad == dvpad. # With receipt 3 set, we are generating instances for datatype == {fp16 || bfp16}, bias == {no || alibi}, deterministic == off, and dpad == dvpad.
execute_process( execute_process(
COMMAND ${PYTHON_EXECUTABLE} ${FMHA_SRC_FOLDER}/generate.py COMMAND ${Python3_EXECUTABLE} ${FMHA_SRC_FOLDER}/generate.py
--list_blobs ${FMHA_CPP_FOLDER}/blob_list.txt --list_blobs ${FMHA_CPP_FOLDER}/blob_list.txt
--api ${FMHA_KNOWN_APIS} --api ${FMHA_KNOWN_APIS}
--receipt 3 --receipt 3
...@@ -50,7 +50,7 @@ endif() ...@@ -50,7 +50,7 @@ endif()
# With receipt 3 set, we are generating instances for datatype == {fp16 || bfp16}, bias == {no || alibi}, deterministic == off, and dpad == dvpad. # With receipt 3 set, we are generating instances for datatype == {fp16 || bfp16}, bias == {no || alibi}, deterministic == off, and dpad == dvpad.
add_custom_command( add_custom_command(
OUTPUT ${FMHA_GEN_BLOBS} OUTPUT ${FMHA_GEN_BLOBS}
COMMAND ${PYTHON_EXECUTABLE} ${FMHA_SRC_FOLDER}/generate.py COMMAND ${Python3_EXECUTABLE} ${FMHA_SRC_FOLDER}/generate.py
--output_dir ${FMHA_CPP_FOLDER} --output_dir ${FMHA_CPP_FOLDER}
--api ${FMHA_KNOWN_APIS} --api ${FMHA_KNOWN_APIS}
--receipt 3 --receipt 3
......
example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp
example/01_gemm/run_gemm_example_streamk_v2.inc
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_mnpadding_instance.cpp
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp
profiler/src/profile_gemm_universal_streamk.cpp
modified_files.txt
...@@ -48,6 +48,7 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -48,6 +48,7 @@ bool profile_gemm_universal_batched_impl(int do_verification,
int StrideB, int StrideB,
int StrideC, int StrideC,
int BatchCount, int BatchCount,
int KBatch,
int n_warmup, int n_warmup,
int n_iter, int n_iter,
uint64_t rotating = 0) uint64_t rotating = 0)
...@@ -147,14 +148,23 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -147,14 +148,23 @@ bool profile_gemm_universal_batched_impl(int do_verification,
float best_ave_time = 0; float best_ave_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
float best_kbatch = 0;
// profile device op instances // profile device op instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
std::unique_ptr<tensor_operation::device::BaseArgument> argument_ptr; std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};
// false branch for multi d dl kernel
argument_ptr = if(KBatch > 0)
{
kbatch_list = {KBatch};
}
for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
{}, {},
...@@ -173,15 +183,13 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -173,15 +183,13 @@ bool profile_gemm_universal_batched_impl(int do_verification,
BatchStrideC, BatchStrideC,
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{},
kbatch_curr);
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
float ave_time = invoker_ptr->Run( float ave_time = invoker_ptr->Run(
...@@ -199,7 +207,7 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -199,7 +207,7 @@ bool profile_gemm_universal_batched_impl(int do_verification,
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << std::endl; << " GB/s, " << op_name << ", KBatch " << kbatch_curr << std::endl;
if(tflops > best_tflops) if(tflops > best_tflops)
{ {
...@@ -207,6 +215,7 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -207,6 +215,7 @@ bool profile_gemm_universal_batched_impl(int do_verification,
best_tflops = tflops; best_tflops = tflops;
best_ave_time = ave_time; best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec; best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
} }
if(do_verification) if(do_verification)
...@@ -219,7 +228,8 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -219,7 +228,8 @@ bool profile_gemm_universal_batched_impl(int do_verification,
{ {
LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",") LogRangeAsType<float>(
std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>( LogRangeAsType<float>(
std::cout << "c_device: ", c_g_m_n_device_result.mData, ",") std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
...@@ -229,7 +239,9 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -229,7 +239,9 @@ bool profile_gemm_universal_batched_impl(int do_verification,
} }
else else
{ {
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
} }
} }
...@@ -270,8 +282,8 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -270,8 +282,8 @@ bool profile_gemm_universal_batched_impl(int do_verification,
std::cout << " B = " << BatchCount << " M = " << M << " N = " << N << " K = " << K std::cout << " B = " << BatchCount << " M = " << M << " N = " << N << " K = " << K
<< " StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC << " StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC
<< ": " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " KBatch = " << best_kbatch << ": " << best_ave_time << " ms, " << best_tflops
<< " GB/s, " << best_op_name << std::endl; << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass; return pass;
} }
......
...@@ -31,7 +31,7 @@ enum struct GemmDataType ...@@ -31,7 +31,7 @@ enum struct GemmDataType
int profile_batched_gemm_universal(int argc, char* argv[]) int profile_batched_gemm_universal(int argc, char* argv[])
{ {
if(argc != 18 && argc != 21) if(argc != 19 && argc != 22)
{ {
// clang-format off // clang-format off
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
...@@ -44,11 +44,11 @@ int profile_batched_gemm_universal(int argc, char* argv[]) ...@@ -44,11 +44,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n"); printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=n0, 1=yes)\n"); printf("arg7: time kernel (0=n0, 1=yes)\n");
printf("arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount\n"); printf("arg8 to 18: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount, KBatch\n");
printf("optional:\n"); printf("optional:\n");
printf("arg18: number of warm-up cycles (default 1)\n"); printf("arg19: number of warm-up cycles (default 1)\n");
printf("arg19: number of iterations (default 10)\n"); printf("arg20: number of iterations (default 10)\n");
printf("arg20: memory for rotating buffer (default 0, size in MB)\n"); printf("arg21: memory for rotating buffer (default 0, size in MB)\n");
// clang-format on // clang-format on
exit(1); exit(1);
} }
...@@ -56,11 +56,11 @@ int profile_batched_gemm_universal(int argc, char* argv[]) ...@@ -56,11 +56,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
int n_warmup = 1; int n_warmup = 1;
int n_iter = 10; int n_iter = 10;
uint64_t rotating = 0; uint64_t rotating = 0;
if(argc == 21) if(argc == 22)
{ {
n_warmup = std::stoi(argv[18]); n_warmup = std::stoi(argv[19]);
n_iter = std::stoi(argv[19]); n_iter = std::stoi(argv[20]);
rotating = std::stoull(argv[20]) * 1024 * 1024; rotating = std::stoull(argv[21]) * 1024 * 1024;
} }
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
...@@ -83,6 +83,7 @@ int profile_batched_gemm_universal(int argc, char* argv[]) ...@@ -83,6 +83,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
const int BatchStrideC = std::stoi(argv[16]); const int BatchStrideC = std::stoi(argv[16]);
const int BatchCount = std::stoi(argv[17]); const int BatchCount = std::stoi(argv[17]);
const int KBatch = std::stoi(argv[18]);
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) #if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using F8 = ck::f8_t; using F8 = ck::f8_t;
...@@ -159,6 +160,7 @@ int profile_batched_gemm_universal(int argc, char* argv[]) ...@@ -159,6 +160,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
StrideB_, StrideB_,
StrideC_, StrideC_,
BatchCount, BatchCount,
KBatch,
n_warmup, n_warmup,
n_iter, n_iter,
rotating); rotating);
......
...@@ -82,7 +82,7 @@ def parse_logfile(logfile): ...@@ -82,7 +82,7 @@ def parse_logfile(logfile):
StrideA=[] StrideA=[]
StrideB=[] StrideB=[]
StrideC=[] StrideC=[]
if 'perf_gemm.log' in logfile: if 'perf_gemm' in logfile and 'gemm_bilinear' not in logfile:
for line in open(logfile): for line in open(logfile):
if 'Best Perf' in line: if 'Best Perf' in line:
lst=line.split() lst=line.split()
...@@ -260,7 +260,7 @@ def main(): ...@@ -260,7 +260,7 @@ def main():
conn = sqlEngine.connect() conn = sqlEngine.connect()
#save gemm performance tests: #save gemm performance tests:
if 'perf_gemm.log' in filename: if 'perf_gemm' in filename and 'gemm_bilinear' not in filename:
#write the ck_gemm_test_params table only needed once the test set changes #write the ck_gemm_test_params table only needed once the test set changes
#post_test_params(test_list,conn) #post_test_params(test_list,conn)
for i in range(1,len(results)+1): for i in range(1,len(results)+1):
...@@ -332,7 +332,7 @@ def main(): ...@@ -332,7 +332,7 @@ def main():
table_name="ck_fmha_bwd_tflops" table_name="ck_fmha_bwd_tflops"
tflops_base = get_baseline(table_name,conn) tflops_base = get_baseline(table_name,conn)
store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, conn) store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, sqlEngine)
conn.close() conn.close()
#compare the results to the baseline if baseline exists #compare the results to the baseline if baseline exists
......
...@@ -11,9 +11,22 @@ ...@@ -11,9 +11,22 @@
#process results #process results
python3 process_perf_data.py perf_gemm.log python3 process_perf_data.py perf_gemm.log
python3 process_perf_data.py perf_onnx_gemm.log
python3 process_perf_data.py perf_resnet50_N256.log python3 process_perf_data.py perf_resnet50_N256.log
python3 process_perf_data.py perf_resnet50_N4.log python3 process_perf_data.py perf_resnet50_N4.log
file=./perf_onnx_gemm_gfx10.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_onnx_gemm_gfx10.log
fi
file=./perf_onnx_gemm_gfx11.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_onnx_gemm_gfx11.log
fi
file=./perf_onnx_gemm_gfx12.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_onnx_gemm_gfx12.log
fi
file=./perf_fmha_fwd_gfx942.log file=./perf_fmha_fwd_gfx942.log
if [ -e "$file" ]; then if [ -e "$file" ]; then
python3 process_perf_data.py perf_fmha_fwd_gfx942.log python3 process_perf_data.py perf_fmha_fwd_gfx942.log
......
...@@ -24,6 +24,18 @@ python3 process_perf_data.py perf_splitK_gemm.log ...@@ -24,6 +24,18 @@ python3 process_perf_data.py perf_splitK_gemm.log
python3 process_perf_data.py perf_onnx_gemm.log python3 process_perf_data.py perf_onnx_gemm.log
python3 process_perf_data.py perf_mixed_gemm.log python3 process_perf_data.py perf_mixed_gemm.log
file=./perf_onnx_gemm_gfx10.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_onnx_gemm_gfx10.log
fi
file=./perf_onnx_gemm_gfx11.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_onnx_gemm_gfx11.log
fi
file=./perf_onnx_gemm_gfx12.log
if [ -e "$file" ]; then
python3 process_perf_data.py perf_onnx_gemm_gfx12.log
fi
file=./perf_fmha_fwd_gfx942.log file=./perf_fmha_fwd_gfx942.log
if [ -e "$file" ]; then if [ -e "$file" ]; then
python3 process_perf_data.py perf_fmha_fwd_gfx942.log python3 process_perf_data.py perf_fmha_fwd_gfx942.log
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
# post your new test results to the database and compare them to the baseline # post your new test results to the database and compare them to the baseline
# please contact Illia.Silin@amd.com for more details # please contact Illia.Silin@amd.com for more details
# #
# run the script as "./run_full_performance_tests.sh <verification> <tag for your test environment> <branch name> < node name> # run the script as "./run_full_performance_tests.sh <verification> <tag for your test environment> <branch name> <node name>
# input arguments: # input arguments:
# verification = 0 : do not verify result correctness on CPU # verification = 0 : do not verify result correctness on CPU
# = 1 : verifuy correctness on CPU (may take a long time) # = 1 : verifuy correctness on CPU (may take a long time)
......
#!/bin/bash
#
# in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/
# run the script as "./run_gemm_performance_tests.sh <verification> <tag for your test environment> <branch name> <node name> <arch>
# input arguments:
# verification = 0 : do not verify result correctness on CPU
# = 1 : verify correctness on CPU (may take a long time)
# environment tag : a string describing the specifics of your test environment
# branch name : name of the branch in git repo (git status | grep -e 'On branch')
# node name : $hostname
# arch : GPU architecture, e.g. "gfx9" or "gfx1100"
#get the command line arguments:
export verify=$1
echo 'Verification: ' $verify
export env_type=$2
echo 'Environment type: ' $env_type
export branch=$3
echo 'Branch name: ' $branch
export host_name=$4
echo 'Host name: ' $host_name
export arch=$5
echo 'GPU architecture: ' $arch
function print_log_header(){
rm -f $1;
echo 'On branch ' $3 &> $1;
echo 'Node name: ' $4 >> $1;
#get GPU_arch and number of compute units from rocminfo
echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1;
rocminfo | grep "Compute Unit:" >> $1;
hipcc --version | grep -e 'HIP version' >> $1;
echo 'Environment type: ' $2 >> $1;
/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1;
}
#run ONNX gemm tests
export onnx_log="perf_onnx_gemm_$arch.log"
print_log_header $onnx_log $env_type $branch $host_name
./profile_onnx_gemm.sh gemm 0 0 $verify 1 0 1 2>&1 | tee -a $onnx_log
./profile_onnx_gemm.sh gemm 1 0 $verify 1 0 1 2>&1 | tee -a $onnx_log
#!/bin/bash #!/bin/bash
# #
# in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/ # in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/
# run the script as "./run_performance_tests.sh <verification> <tag for your test environment> <branch name> < node name> # run the script as "./run_performance_tests.sh <verification> <tag for your test environment> <branch name> <node name>
# input arguments: # input arguments:
# verification = 0 : do not verify result correctness on CPU # verification = 0 : do not verify result correctness on CPU
# = 1 : verify correctness on CPU (may take a long time) # = 1 : verify correctness on CPU (may take a long time)
...@@ -51,20 +51,11 @@ print_log_header $gemm_log $env_type $branch $host_name ...@@ -51,20 +51,11 @@ print_log_header $gemm_log $env_type $branch $host_name
./profile_gemm.sh gemm 2 3 $verify 1 0 1 | tee -a $gemm_log ./profile_gemm.sh gemm 2 3 $verify 1 0 1 | tee -a $gemm_log
./profile_gemm.sh gemm 3 3 $verify 1 0 1 | tee -a $gemm_log ./profile_gemm.sh gemm 3 3 $verify 1 0 1 | tee -a $gemm_log
#run grouped_fwd fp16 tests #run ONNX gemm tests
export grouped_conv_fwd_log="perf_grouped_conv_fwd_fp16.log" export onnx_log="perf_onnx_gemm.log"
print_log_header $conv_fwd_log $env_type $branch $host_name print_log_header $onnx_log $env_type $branch $host_name
./profile_grouped_conv_fwd.sh grouped_conv_fwd 1 1 0 $verify 1 0 1 256 2>&1 | tee -a $grouped_conv_fwd_log ./profile_onnx_gemm.sh gemm 0 0 $verify 1 0 1 2>&1 | tee -a $onnx_log
./profile_onnx_gemm.sh gemm 1 0 $verify 1 0 1 2>&1 | tee -a $onnx_log
#run grouped_bwd_data fp16 tests
export grouped_conv_bwd_data_log="perf_grouped_conv_bwd_data_fp16.log"
print_log_header $grouped_conv_bwd_data_log $env_type $branch $host_name
./profile_grouped_conv_bwd_data.sh grouped_conv_bwd_data 1 1 $verify 1 0 1 256 2>&1 | tee -a $grouped_conv_bwd_data_log
#run grouped_bwd_weight fp16 tests
export grouped_conv_bwd_weight_log="perf_grouped_conv_bwd_weight_fp16.log"
print_log_header $grouped_conv_bwd_weight_log $env_type $branch $host_name
./profile_grouped_conv_bwd_weight.sh grouped_conv_bwd_weight 1 1 $verify 1 0 1 256 1 2>&1 | tee -a $grouped_conv_bwd_weight_log
#run resnet50 tests #run resnet50 tests
export resnet256_log="perf_resnet50_N256.log" export resnet256_log="perf_resnet50_N256.log"
......
...@@ -10,33 +10,27 @@ ...@@ -10,33 +10,27 @@
using F16 = ck_tile::half_t; using F16 = ck_tile::half_t;
using F32 = float; using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
static constexpr auto Intrawave = ck_tile::GemmPipelineScheduler::Intrawave; using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
static constexpr auto Interwave = ck_tile::GemmPipelineScheduler::Interwave; ck_tile::GemmPipelineScheduler::Intrawave>;
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
template <typename Tuple> ck_tile::GemmPipelineScheduler::Interwave>;
class TestCkTileGemmMemPipelineIntrawave : public TestCkTileGemmMemPipeline<Tuple, Intrawave>
{
};
template <typename Tuple>
class TestCkTileGemmMemPipelineInterwave : public TestCkTileGemmMemPipeline<Tuple, Interwave>
{
};
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler
std::tuple< Row, Col, Row, F16, F16, F32, F16>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave>,
std::tuple< Col, Row, Row, F16, F16, F32, F16>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave>,
std::tuple< Row, Row, Row, F16, F16, F32, F16>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave>,
std::tuple< Col, Col, Row, F16, F16, F32, F16> std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave>
>; >;
// clang-format on // clang-format on
TYPED_TEST_SUITE(TestCkTileGemmMemPipelineIntrawave, KernelTypes); TYPED_TEST_SUITE(TestCkTileGemmMemPipeline, KernelTypes);
TYPED_TEST_SUITE(TestCkTileGemmMemPipelineInterwave, KernelTypes);
#include "test_gemm_mem_pipeline_ut_cases.inc" #include "test_gemm_mem_pipeline_ut_cases.inc"
This diff is collapsed.
add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_xdl_wmma.cpp) add_gtest_executable(test_grouped_convnd_bwd_data_xdl test_grouped_convnd_bwd_data_xdl.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) target_link_libraries(test_grouped_convnd_bwd_data_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
endif()
add_gtest_executable(test_grouped_convnd_bwd_data_wmma test_grouped_convnd_bwd_data_wmma.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_data_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
endif() endif()
add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp) add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp)
if(result EQUAL 0) if(result EQUAL 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment