Unverified commit 566b6480, authored by Illia Silin, committed by GitHub

Code clean-up (#1285)

* code clean-up

* remove the profiling output samples
parent fcba889e
@@ -202,7 +202,7 @@ endif()
 option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
-option(USE_OPT_NAVI3X "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF)
+option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF)
 if(USE_BITINT_EXTENSION_INT4)
     add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
@@ -210,10 +210,10 @@ if(USE_BITINT_EXTENSION_INT4)
     message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
 endif()
-if(USE_OPT_NAVI3X)
+if(USE_OPT_GFX11)
     add_compile_options(-mcumode)
     add_compile_options(-mno-wavefrontsize64)
-    message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}")
+    message("CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}")
 endif()
 ## Threads
...
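For reference, a configure command using the renamed option could look like the sketch below. This is illustrative only: the GPU_TARGETS list is an assumption, and the remaining flags mirror the gfx1030/gfx1101 pipeline stages further down.
```
# Sketch: enable the renamed USE_OPT_GFX11 option (formerly USE_OPT_NAVI3X) when targeting gfx11 GPUs.
mkdir -p build && cd build
cmake -D CMAKE_INSTALL_PREFIX=../install \
      -D GPU_TARGETS="gfx1100;gfx1101" \
      -D USE_OPT_GFX11=ON \
      -D CMAKE_CXX_FLAGS=" -O3 " ..
```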
@@ -515,30 +515,25 @@ def Build_CK(Map conf=[:]){
 withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
 timeout(time: 24, unit: 'HOURS')
 {
-//check whether running on Navi or MI300 node
-def navi_node = 0
-def mi300_node = 0
+//check whether to run performance tests on this node
+def do_perf_tests = 0
 sh 'rocminfo | tee rocminfo.log'
-if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') ){
-navi_node = 1
-echo "This is a Navi node"
-}
-if ( runShell('grep -n "gfx942" rocminfo.log') ){
-mi300_node = 1
-echo "This is MI300 node"
+if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') || runShell('grep -n "gfx942" rocminfo.log') ){
+do_perf_tests = 1
+echo "Stash profiler and run performance tests"
 }
 cmake_build(conf)
 dir("build"){
 //run tests and examples
 sh 'make -j check'
-if (params.RUN_PERFORMANCE_TESTS && navi_node == 0 && mi300_node == 0 ){
+if (params.RUN_PERFORMANCE_TESTS && do_perf_tests == 0 ){
 //we only need the ckProfiler to run the performance tests, so we pack and stash it
-//do not stash profiler on Navi or MI300 nodes
+//do not stash profiler on nodes where we don't need to run performance tests
 sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler'
 stash name: "ckProfiler.tar.gz"
 }
-if (params.RUN_FULL_QA && mi300_node == 0 ){
-// build deb packages for all MI100/200/300 targets and prepare to export
+if (params.RUN_FULL_QA && do_perf_tests == 0 ){
+// build deb packages for all gfx9 targets and prepare to export
 sh 'make -j package'
 archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb'
 archiveArtifacts artifacts: 'composablekernel-tests_*.deb'
@@ -546,7 +541,7 @@ def Build_CK(Map conf=[:]){
 stash name: "ckprofiler_0.2.0_amd64.deb"
 }
 }
-if (params.hipTensor_test && navi_node == 0 ){
+if (params.hipTensor_test && do_perf_tests == 0 ){
 //build and test hipTensor
 sh """#!/bin/bash
 rm -rf "${params.hipTensor_branch}".zip
@@ -814,7 +809,7 @@ pipeline {
 {
 parallel
 {
-stage("Run Codegen Tests on MI200")
+stage("Run Codegen Tests on gfx90a")
 {
 when {
 beforeAgent true
@@ -865,7 +860,7 @@ pipeline {
 cleanWs()
 }
 }
-stage("Build CK and run Tests on MI300")
+stage("Build CK and run Tests on gfx942")
 {
 when {
 beforeAgent true
@@ -885,7 +880,7 @@ pipeline {
 cleanWs()
 }
 }
-stage("Build CK and run Tests on MI200")
+stage("Build CK and run Tests on gfx90a")
 {
 when {
 beforeAgent true
@@ -925,13 +920,13 @@ pipeline {
 cleanWs()
 }
 }
-stage("Build CK and run Tests on Navi21")
+stage("Build CK and run Tests on gfx1030")
 {
 when {
 beforeAgent true
 expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
 }
-agent{ label rocmnode("navi21") }
+agent{ label rocmnode("gfx1030") }
 environment{
 setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """
 execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
@@ -945,13 +940,13 @@ pipeline {
 cleanWs()
 }
 }
-stage("Build CK and run Tests on Navi32")
+stage("Build CK and run Tests on gfx1101")
 {
 when {
 beforeAgent true
 expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
 }
-agent{ label rocmnode("navi32") }
+agent{ label rocmnode("gfx1101") }
 environment{
 setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """
 execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
...
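Outside of Jenkins, the node gate introduced in Build_CK above reduces to a grep over rocminfo output. A minimal stand-alone sketch follows (assuming rocminfo is on PATH; the script and its final echo are illustrative, not part of the pipeline):
```
#!/bin/bash
# Sketch of the pipeline's performance-test gate:
# flag nodes whose GPU is gfx1030, gfx1101, or gfx942.
rocminfo | tee rocminfo.log
do_perf_tests=0
if grep -q -E "gfx1030|gfx1101|gfx942" rocminfo.log; then
    do_perf_tests=1
    echo "Stash profiler and run performance tests"
fi
echo "do_perf_tests=${do_perf_tests}"
```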
@@ -181,4 +181,3 @@ int main(int argc, char* argv[])
 {1, 1, 1} /*filter_dilations*/);
 return 0;
 }
-// MI100 Perf: 0.255178 ms, 1698.9 GB/s,
@@ -7,17 +7,3 @@
 #arg3: run kernel # of times (>1)
 ./bin/example_gemm_xdl 0 1 5
 ```
-Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
-```
-a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
-b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
-c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
-arg.a_grid_desc_k0_m_k1_{512, 3840, 8}
-arg.b_grid_desc_k0_n_k1_{512, 4096, 8}
-arg.c_grid_desc_m_n_{ 3840, 4096}
-launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
-Warm up
-Start running 5 times...
-Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s
-```
@@ -9,20 +9,3 @@
 #arg11 to 12: alpha, beta
 ./bin/example_gemm_bilinear_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096 4096 0.5 0.5
 ```
-Result (MI100 @ 1502Mhz, 184.6TFlops peak FP16)
-```
-a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
-b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
-c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
-c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
-arg.a_grid_desc_k0_m_k1_{512, 3840, 8}
-arg.b_grid_desc_k0_n_k1_{512, 4096, 8}
-arg.c0_grid_desc_m_n_{ 3840, 4096}
-arg.c_grid_desc_m_n_{ 3840, 4096}
-launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
-Warm up
-Start running 1 times...
-Perf: 0.936965 ms, 137.517 TFlops, 102.959 GB/s
-error: 0
-max_diff: 0, 558.5, 558.5
-```
@@ -8,16 +8,3 @@
 #arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE"
 ./bin/example_gemm_add_add_fastgelu_xdl_fp16 1 1 1
 ```
-Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
-```
-a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
-b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
-d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1}
-d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
-e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
-launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
-Warm up 1 time
-Start running 10 times...
-Perf: 1.26914 ms, 101.525 TFlops, 100.804 GB/s, DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8>
-```
@@ -16,17 +16,3 @@
 # <right padding>, (ie RightPy, RightPx for 2D)
 ./bin/example_convnd_fwd_xdl 0 1 100
 ```
-Result (MI100 @ 1087Mhz, 33.4TFlops peak FP32)
-```
-input: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192}
-weights: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192}
-output: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
-arg.a_grid_desc_k0_m_k1_{432, 165888, 4}
-arg.b_grid_desc_k0_n_k1_{432, 256, 4}
-arg.c_grid_desc_m_n_{ 165888, 256}
-launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
-Warm up
-Start running 100 times...
-Perf: 4.43736 ms, 33.0753 TFlops, 150.357 GB/s
-```
@@ -7,19 +7,3 @@
 #arg3: run kernel # of times (>1)
 ./bin/example_grouped_gemm_xdl_fp16 0 1 5
 ```
-Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
-```
-gemm[0] a_m_k: dim 2, lengths {256, 64}, strides {64, 1} b_k_n: dim 2, lengths {64, 128}, strides {1, 64} c_m_n: dim 2, lengths {256, 128}, strides {128, 1}
-gemm[1] a_m_k: dim 2, lengths {512, 128}, strides {128, 1} b_k_n: dim 2, lengths {128, 256}, strides {1, 128} c_m_n: dim 2, lengths {512, 256}, strides {256, 1}
-gemm[2] a_m_k: dim 2, lengths {768, 192}, strides {192, 1} b_k_n: dim 2, lengths {192, 384}, strides {1, 192} c_m_n: dim 2, lengths {768, 384}, strides {384, 1}
-gemm[3] a_m_k: dim 2, lengths {1024, 256}, strides {256, 1} b_k_n: dim 2, lengths {256, 512}, strides {1, 256} c_m_n: dim 2, lengths {1024, 512}, strides {512, 1}
-group: 0 arg.a_grid_desc_k0_m_k1_{8, 256, 8}, arg.b_grid_desc_k0_n_k1_{8, 128, 8}, arg.c_grid_desc_m_n_{ 256, 128}
-group: 1 arg.a_grid_desc_k0_m_k1_{16, 512, 8}, arg.b_grid_desc_k0_n_k1_{16, 256, 8}, arg.c_grid_desc_m_n_{ 512, 256}
-group: 2 arg.a_grid_desc_k0_m_k1_{24, 768, 8}, arg.b_grid_desc_k0_n_k1_{24, 384, 8}, arg.c_grid_desc_m_n_{ 768, 384}
-group: 3 arg.a_grid_desc_k0_m_k1_{32, 1024, 8}, arg.b_grid_desc_k0_n_k1_{32, 512, 8}, arg.c_grid_desc_m_n_{ 1024, 512}
-launch_and_time_kernel: grid_dim {30, 1, 1}, block_dim {256, 1, 1}
-Warm up
-Start running 5 times...
-Perf: 0.037887 ms, 11.0706 TFlops, 90.8132 GB/s, DeviceGroupedGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2>
-```
@@ -7,14 +7,3 @@
 #arg3: time kernel (0=no, 1=yes)
 ./bin/example_contraction_bilinear_xdl_fp32 1 1 1
 ```
-Result (MI100 @ dynammic freq, 46TFlops peak FP32)
-```
-a_ms_ks: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1}
-b_ks_ns: dim 4, lengths {32, 64, 32, 64}, strides {128, 1, 524288, 4096}
-c_ms_ns: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1}
-launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
-Warm up 1 time
-Start running 10 times...
-Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContractionMultipleD_Xdl_CShuffle<256, 256, 128, 16, 4, 4>
-```
@@ -16,15 +16,3 @@ Following arguments (depending on number of spatial dims):
 ./bin/example_grouped_conv_fwd_bias_relu_add_xdl_fp16 1 1 1
 ```
-Result (MI100)
-```
-in: dim 5, lengths {1, 128, 192, 71, 71}, strides {192, 967872, 1, 13632, 192}
-wei: dim 5, lengths {1, 256, 192, 3, 3}, strides {442368, 1728, 1, 576, 192}
-bias: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 0, 1, 0, 0}
-residual: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 0, 1, 0, 0}
-out: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 331776, 1, 9216, 256}
-launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
-Warm up 1 time
-Start running 10 times...
-Perf: 1.55981 ms, 94.0927 TFlops, 213.868 GB/s, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 16, Default>
-```
@@ -8,19 +8,3 @@
 #arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE"
 ./bin/example_gemm_add_multiply_dl_fp16 1 1 1
 ```
-Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
-```
-a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
-b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
-d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1}
-d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
-e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
-arg.a_grid_desc_k0_m0_m1_k1_{2048, 3840, 2}
-arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2}
-arg.e_grid_desc_m_n_{ 3840, 4096}
-launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1}
-Warm up 1 time
-Start running 10 times...
-Perf: 3.99904 ms, 32.22 TFlops, 31.9913 GB/s, DeviceGemmMultipleD_Dl<256, 128, 128, 16, 2, 4, 4, 1>
-```
@@ -236,7 +236,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
 #ifndef CK_WORKAROUND_DENORM_FIX
 #define CK_WORKAROUND_DENORM_FIX 0
 #else
-// enable only on MI200
+// enable only for gfx90a
 #define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
 #endif // CK_WORKAROUND_DENORM_FIX
...
@@ -65,20 +65,20 @@ inline bool is_lds_direct_load_supported()
 ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
 }
-inline bool is_navi1_supported()
+inline bool is_gfx101_supported()
 {
 return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
 ck::get_device_name() == "gfx1012";
 }
-inline bool is_navi2_supported()
+inline bool is_gfx103_supported()
 {
 return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" ||
 ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" ||
 ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
 }
-inline bool is_navi3_supported()
+inline bool is_gfx11_supported()
 {
 return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
 ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
...
@@ -829,7 +829,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
 static bool IsSupportedArgument(const Argument& arg)
 {
-if(ck::is_navi3_supported())
+if(ck::is_gfx11_supported())
 {
 if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
 {
...
@@ -648,7 +648,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
 static bool IsSupportedArgument(const Argument& arg)
 {
 if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
-ck::is_navi2_supported() || ck::is_navi3_supported())
+ck::is_gfx103_supported() || ck::is_gfx11_supported())
 {
 bool pass = true;
 pass = pass && arg.K_ % K1 == 0;
...
@@ -858,7 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
 static bool IsSupportedArgument(const RawArg& arg)
 {
-if(ck::is_navi3_supported())
+if(ck::is_gfx11_supported())
 {
 if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
 {
@@ -1435,7 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
 #if 0
 static bool IsSupportedArgument(const Argument& arg)
 {
-if(ck::is_navi3_supported())
+if(ck::is_gfx11_supported())
 {
 if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
 {
...
@@ -1392,8 +1392,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
 static bool IsSupportedArgument(const Argument& arg)
 {
 // check device
-if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
-ck::is_navi3_supported()))
+if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
+ck::is_gfx11_supported()))
 {
 return false;
 }
...
@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
 static bool IsSupportedArgument(const Argument& arg)
 {
-if(ck::is_navi3_supported())
+if(ck::is_gfx11_supported())
 {
 if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
 is_same_v<AccDataType, int32_t>))
...
@@ -535,8 +535,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
 }
 }
-if(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
-ck::is_navi3_supported())
+if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
+ck::is_gfx11_supported())
 {
 return GridwiseGemm::CheckValidity(
 arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
...
@@ -168,7 +168,7 @@ struct DeviceGemmDpp : public DeviceGemm<ALayout,
 static bool IsSupportedArgument(const Argument& karg)
 {
-if(ck::is_navi2_supported() || ck::is_navi3_supported())
+if(ck::is_gfx103_supported() || ck::is_gfx11_supported())
 {
 return GridwiseGemm::CheckValidity(karg);
 }
...