Unverified Commit 3049b546 authored by Haocong WANG's avatar Haocong WANG Committed by GitHub
Browse files

[GEMM] gemm_universal related optimization (#1453)



* replace buffer_atomic with global_atomic

* fixed global_atomic_add

* added bf16 atomic_add

* format

* clang-format-12

* clean

* clean

* add guards

* Update gtest.cmake

* enabled splitk_gemm_multi_d

* format

* add ckProfiler

* format

* fixed naming

* format

* clean

* clean

* add guards

* fix clang format

* format

* add kbatch printout

* clean

* Add rocm6.2 related gemm optimization

* Limit bf16 atomic usage

* remove redundant RCR gemm_universal instance

* Add RRR fp8 gemm universal instance

* Bug fix

* Add GPU_TARGET guard to FP8/BF8 target

* bug fix

* update cmake

* remove all fp8/bf8 example if arch not support

* Enable fp8 RRR support in ckProfiler

* limit greedy-reverse flag to gemm_universal in ckProfiler

---------
Co-authored-by: default avatarJing Zhang <jizhan@fb.com>
Co-authored-by: default avatarJing Zhang <jizhan@meta.com>
Co-authored-by: default avatarzjing14 <zhangjing14@gmail.com>
Co-authored-by: default avatarIllia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: default avatarillsilin <Illia.Silin@amd.com>
parent 50c42348
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_instances<Interwave, GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_instances<Interwave, GemmNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -45,8 +45,7 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std::tuple< ...@@ -45,8 +45,7 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
...@@ -72,7 +71,11 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple< ...@@ -72,7 +71,11 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple<
// Latency friendly // Latency friendly
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 512, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
// Memory friendly // Memory friendly
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
...@@ -83,7 +86,11 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple< ...@@ -83,7 +86,11 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 512, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 128, 16, 16, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 128, 16, 16, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
......
...@@ -15,7 +15,7 @@ set(GROUPED_CONV3D_BWD_DATA ...@@ -15,7 +15,7 @@ set(GROUPED_CONV3D_BWD_DATA
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp) wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp)
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
list(APPEND GROUPED_CONV3D_BWD_DATA list(APPEND GROUPED_CONV3D_BWD_DATA
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp) xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp)
endif() endif()
......
...@@ -30,7 +30,7 @@ list(APPEND GROUPED_CONV3D_BWD_WEIGHT ...@@ -30,7 +30,7 @@ list(APPEND GROUPED_CONV3D_BWD_WEIGHT
wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp
wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp) wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp)
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
list(APPEND GROUPED_CONV3D_BWD_WEIGHT list(APPEND GROUPED_CONV3D_BWD_WEIGHT
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp) xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
endif() endif()
......
...@@ -4,7 +4,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT_BILINEAR ...@@ -4,7 +4,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT_BILINEAR
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp)
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
list(APPEND GROUPED_CONV3D_BWD_WEIGHT_BILINEAR list(APPEND GROUPED_CONV3D_BWD_WEIGHT_BILINEAR
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp) xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
endif() endif()
......
...@@ -4,7 +4,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT_SCALE ...@@ -4,7 +4,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT_SCALE
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp)
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
list(APPEND GROUPED_CONV3D_BWD_WEIGHT_SCALE list(APPEND GROUPED_CONV3D_BWD_WEIGHT_SCALE
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp) xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
endif() endif()
......
...@@ -43,22 +43,22 @@ set(GROUPED_CONV3D_FWD ...@@ -43,22 +43,22 @@ set(GROUPED_CONV3D_FWD
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp
) )
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
list(APPEND GROUPED_CONV3D_FWD list(APPEND GROUPED_CONV3D_FWD
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp) xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp)
endif() endif()
if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) if((DTYPES MATCHES "fp8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
list(APPEND GROUPED_CONV3D_FWD list(APPEND GROUPED_CONV3D_FWD
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp) xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp)
endif() endif()
if(DTYPES MATCHES "bf8" OR NOT DEFINED DTYPES) if((DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
list(APPEND GROUPED_CONV3D_FWD list(APPEND GROUPED_CONV3D_FWD
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp) xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp)
endif() endif()
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
list(APPEND GROUPED_CONV3D_FWD list(APPEND GROUPED_CONV3D_FWD
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp) xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp)
list(APPEND GROUPED_CONV3D_FWD list(APPEND GROUPED_CONV3D_FWD
......
...@@ -48,6 +48,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification, ...@@ -48,6 +48,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
int StrideD0, int StrideD0,
int StrideD1, int StrideD1,
int StrideE, int StrideE,
int KBatch,
int n_warmup, int n_warmup,
int n_iter, int n_iter,
uint64_t rotating = 0) uint64_t rotating = 0)
...@@ -129,17 +130,17 @@ bool profile_gemm_multiply_multiply_impl(int do_verification, ...@@ -129,17 +130,17 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
d1_device_buf.ToDevice(d1_m_n.mData.data()); d1_device_buf.ToDevice(d1_m_n.mData.data());
using DeviceOp = using DeviceOp =
ck::tensor_operation::device::DeviceGemmMultipleD<ALayout, ck::tensor_operation::device::DeviceGemmMultipleDSplitK<ALayout,
BLayout, BLayout,
ck::Tuple<D0Layout, D1Layout>, ck::Tuple<D0Layout, D1Layout>,
ELayout, ELayout,
ADataType, ADataType,
BDataType, BDataType,
ck::Tuple<D0DataType, D1DataType>, ck::Tuple<D0DataType, D1DataType>,
EDataType, EDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CElementOp>; CElementOp>;
// get device op instances // get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
...@@ -182,104 +183,128 @@ bool profile_gemm_multiply_multiply_impl(int do_verification, ...@@ -182,104 +183,128 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
float best_ave_time = 0; float best_ave_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
float best_kbatch = 0;
// profile device GEMM instances // profile device GEMM instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
auto argument_ptr = std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
std::array<const void*, 2>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 2>{StrideD0, StrideD1},
StrideE,
a_element_op,
b_element_op,
c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 0, n_warmup, n_iter}); if(KBatch > 0)
{
kbatch_list = {KBatch};
}
if(do_verification) for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
std::array<const void*, 2>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 2>{StrideD0, StrideD1},
StrideE,
kbatch_curr,
a_element_op,
b_element_op,
c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
c_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); // re-init C to zero before profiling next kernel
c_device_buf.SetZero();
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_log) if(do_verification)
{ {
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl; c_device_buf.FromDevice(e_m_n_device_result.mData.data());
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host : ", e_m_n_host_result.mData, ",") pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
<< std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", e_m_n_device_result.mData, ",") if(do_log)
<< std::endl; {
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", e_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", e_m_n_device_result.mData, ",")
<< std::endl;
}
} }
}
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
float ave_time = invoker_ptr->Run( float ave_time = invoker_ptr->Run(argument_ptr.get(),
argument_ptr.get(), StreamConfig{nullptr,
StreamConfig{ time_kernel,
nullptr, time_kernel, 0, n_warmup, n_iter, rotating_count > 1, rotating_count}); 0,
n_warmup,
n_iter,
rotating_count > 1,
rotating_count});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< gb_per_sec << " GB/s, " << op_name << std::endl; << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<< kbatch_curr << std::endl;
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8
// set softer tolerances for fp8 // set softer tolerances for fp8
if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> || if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
is_same_v<EDataType, f8_t>) is_same_v<EDataType, f8_t>)
{ {
std::string msg = "Error: Incorrect results!"; std::string msg = "Error: Incorrect results!";
double rtol = 1e-1; double rtol = 1e-1;
double atol = 1e-1; double atol = 1e-1;
pass = pass & ck::utils::check_err( pass = pass & ck::utils::check_err(
e_m_n_device_result, e_m_n_host_result, msg, rtol, atol); e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
} }
else else
{ {
#endif #endif
pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8
} }
#endif #endif
if(tflops > best_tflops) if(tflops > best_tflops && ave_time > 1e-10)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
}
else
{ {
best_op_name = op_name; std::cout << op_ptr->GetTypeString() << " does not support this problem"
best_tflops = tflops; << std::endl;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
} }
} }
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
} }
if constexpr(is_same<EDataType, float>::value) if constexpr(is_same<EDataType, float>::value)
...@@ -318,9 +343,9 @@ bool profile_gemm_multiply_multiply_impl(int do_verification, ...@@ -318,9 +343,9 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
} }
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideE = " << StrideE << " : " << best_ave_time << " StrideB = " << StrideB << " StrideE = " << StrideE << " KBatch = " << best_kbatch
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< best_op_name << std::endl; << " GB/s, " << best_op_name << std::endl;
return pass; return pass;
} }
......
...@@ -152,7 +152,7 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -152,7 +152,7 @@ bool profile_gemm_universal_impl(int do_verification,
// profile device GEMM instances // profile device GEMM instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 19, 20, 32, 38}; std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};
if(KBatch > 0) if(KBatch > 0)
{ {
...@@ -249,7 +249,7 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -249,7 +249,7 @@ bool profile_gemm_universal_impl(int do_verification,
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<< kbatch_curr << std::endl; << kbatch_curr << std::endl;
if(tflops > best_tflops) if(tflops > best_tflops && ave_time > 1e-10)
{ {
best_op_name = op_name; best_op_name = op_name;
best_tflops = tflops; best_tflops = tflops;
......
...@@ -34,7 +34,7 @@ enum struct GemmDataType ...@@ -34,7 +34,7 @@ enum struct GemmDataType
int profile_gemm_multiply_multiply(int argc, char* argv[]) int profile_gemm_multiply_multiply(int argc, char* argv[])
{ {
if(argc != 16 && argc != 19) if(argc != 16 && argc != 20)
{ {
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: " printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
...@@ -50,9 +50,10 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) ...@@ -50,9 +50,10 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
printf("arg7: time kernel (0=no, 1=yes)\n"); printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 15: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n"); printf("arg8 to 15: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n");
printf("optional:\n"); printf("optional:\n");
printf("arg16: number of warm-up cycles (default 1)\n"); printf("arg16: number of kbatch (default 1)\n");
printf("arg17: number of iterations (default 10)\n"); printf("arg17: number of warm-up cycles (default 1)\n");
printf("arg18: memory for rotating buffer (default 0, size in MB)\n"); printf("arg18: number of iterations (default 10)\n");
printf("arg19: memory for rotating buffer (default 0, size in MB)\n");
exit(1); exit(1);
} }
...@@ -76,11 +77,13 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) ...@@ -76,11 +77,13 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
int n_warmup = 1; int n_warmup = 1;
int n_iter = 10; int n_iter = 10;
uint64_t rotating = 0; uint64_t rotating = 0;
if(argc == 19) int KBatch = 1;
if(argc == 20)
{ {
n_warmup = std::stoi(argv[16]); KBatch = std::stoi(argv[16]);
n_iter = std::stoi(argv[17]); n_warmup = std::stoi(argv[17]);
rotating = std::stoull(argv[18]) * 1024 * 1024; n_iter = std::stoi(argv[18]);
rotating = std::stoull(argv[19]) * 1024 * 1024;
} }
using F32 = float; using F32 = float;
...@@ -146,6 +149,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) ...@@ -146,6 +149,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
(StrideD0 < 0) ? DefaultStrideD0 : StrideD0, (StrideD0 < 0) ? DefaultStrideD0 : StrideD0,
(StrideD1 < 0) ? DefaultStrideD1 : StrideD1, (StrideD1 < 0) ? DefaultStrideD1 : StrideD1,
(StrideE < 0) ? DefaultStrideE : StrideE, (StrideE < 0) ? DefaultStrideE : StrideE,
KBatch,
n_warmup, n_warmup,
n_iter, n_iter,
rotating); rotating);
......
...@@ -171,6 +171,10 @@ int profile_gemm_universal(int argc, char* argv[]) ...@@ -171,6 +171,10 @@ int profile_gemm_universal(int argc, char* argv[])
{ {
return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{}); return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{});
} }
else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{}); return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{});
......
...@@ -85,9 +85,11 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -85,9 +85,11 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
const auto StrideCs = argToIntArray(argv[13]); const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1; const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using F8 = ck::f8_t; #if defined(CK_ENABLE_FP8)
using F8 = ck::f8_t;
#endif
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using I8 = int8_t; using I8 = int8_t;
......
...@@ -44,17 +44,22 @@ class TestGemmUniversal_MK_NK ...@@ -44,17 +44,22 @@ class TestGemmUniversal_MK_NK
using KernelTypes_MK_KN = ::testing::Types< using KernelTypes_MK_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType // ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F16, F16, F16, F16>, std::tuple< F16, F16, F16, F16>,
#if (defined CK_ENABLE_FP8)
std::tuple< F16, F8, F16, F16>, std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>, std::tuple< F8, F16, F16, F16>,
std::tuple< F8, F8, F8, BF16>,
#endif
std::tuple< BF16, BF16, BF16, BF16> std::tuple< BF16, BF16, BF16, BF16>
>; >;
using KernelTypes_MK_NK = ::testing::Types< using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType // ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F16, F16, F16, F16>, std::tuple< F16, F16, F16, F16>,
#if (defined CK_ENABLE_FP8)
std::tuple< F16, F8, F16, F16>, std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>, std::tuple< F8, F16, F16, F16>,
std::tuple< BF16, BF16, BF16, BF16>, std::tuple< F8, F8, F8, BF16>,
std::tuple< F8, F8, F8, BF16> #endif
std::tuple< BF16, BF16, BF16, BF16>
>; >;
// clang-format on // clang-format on
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment