"vscode:/vscode.git/clone" did not exist on "c50959259941f4dee8de9c63f2792c7a76ee9646"
Unverified commit d9f1ead3, authored by Mingtao Gu and committed by GitHub

Added Int4 mixed batch gemm support (#1839)



* remove redundant kernels.

* added batched_gemm_xdl_fp16int4_b_scale_v3

* Enabled the split K.

* added the batched_gemm_b_scale ckProfiler; ran into a function issue

* fix some typos

* fix ckProfiler build issue

* fix some bugs

* updated some debug info

* comment some code

* Fix

* fixed some bugs and refactor the code

* fixed a function bug.

* formatted files.

* formatted

* uncommented the ckProfiler CMakeLists

* fixed.

* fix ckProfiler for batched_gemm_b_scale

---------
Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
parent a8c5bd9b
@@ -22,3 +22,6 @@ if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp)
     add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4)
 endif()
+add_example_executable(example_batched_gemm_xdl_fp16int4_b_scale_v3 batched_gemm_xdl_fp16int4_b_scale_v3.cpp)
+add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16int4_b_scale_v3)
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = ck::pk_i4_t;
using BScaleDataType = ck::half_t;
using AccDataType = F32;
using CShuffleDataType = F16;
using CDataType = F16;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr ck::index_t Scale_Block_N = 1;
static constexpr ck::index_t Scale_Block_K = 128;
static constexpr ck::index_t KPerBlock = 256;
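// Quantization layout used by this example: B is stored as ck::pk_i4_t (two 4-bit values
// packed per byte) and is dequantized in the kernel against fp16 B scales, one scale per
// Scale_Block_N x Scale_Block_K (here 1 x 128) tile of B. KPerBlock, reused below by the
// optional PermuteB pre-shuffle, is presumably the K tile each work-group consumes per
// main-loop iteration.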
// clang-format off
using DeviceBatchedGemmV2Instance =
ck::tensor_operation::device::DeviceBatchedGemm_Xdl_CShuffleV3_BScale<
ALayout, BLayout, CLayout,
ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
256, Scale_Block_N, Scale_Block_K,
16, 64,
KPerBlock, 8, 32,
16, 16,
1, 1,
S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>;
// clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
AccDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
#include "run_batched_gemm_example_fp16int4_b_scale.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_fp16_int4_b_scale_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <random>
#pragma once
struct ProblemSize final
{
ck::index_t M = 128;
ck::index_t N = 128;
ck::index_t K = 384;
ck::index_t stride_A = K;
ck::index_t stride_B = K;
ck::index_t stride_C = N;
ck::index_t batch_stride_A = M * K;
ck::index_t batch_stride_B = K * N;
ck::index_t batch_stride_C = M * N;
// Batched Gemm count
ck::index_t batch_count = 2;
// Split K count
ck::index_t KBatch = 1;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
};
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1e-1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1.5e-1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 16.1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 8192.1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M,
N,
K,
stride_A,
stride_B,
stride_C,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_count,
KBatch] = problem_size;
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz});
}
else
{
return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride});
}
};
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
ck::index_t batch_BScale_Stride =
((K + Scale_Block_K - 1) / Scale_Block_K) * ((N + Scale_Block_N - 1) / Scale_Block_N);
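// The B-scale tensor stores one BScaleDataType value per Scale_Block_N x Scale_Block_K block
// of B, i.e. a ceil(K / Scale_Block_K) x ceil(N / Scale_Block_N) grid per batch;
// Scale_Stride_BN and batch_BScale_Stride above are the corresponding strides.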
Tensor<ADataType> a_g_m_k(
f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{}));
Tensor<BDataType> b_g_k_n(
f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{}));
Tensor<BDataType> b_g_k_n_permute(
f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{}));
Tensor<BScaleDataType> b1_g_k_n(
f_host_tensor_descriptor(batch_count,
(K + Scale_Block_K - 1) / Scale_Block_K,
(N + Scale_Block_N - 1) / Scale_Block_N,
Scale_Stride_BN,
batch_BScale_Stride,
BLayout{}));
switch(config.init_method)
{
case 0:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_g_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
b1_g_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
break;
case 2:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
b1_g_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
break;
case 3:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_g_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
b1_g_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
break;
case 4:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_g_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
break;
case 5:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
b1_g_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.5, 0.5});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
}
Tensor<CDataType> c_g_m_n_host_result(
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
Tensor<CDataType> c_g_m_n_device_result(
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
std::cout << "b1_g_k_n: " << b1_g_k_n.mDesc << std::endl;
std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl;
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_g_k_n_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize());
DeviceMem b1_g_scale_device_buf(sizeof(BScaleDataType) * b1_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_g_m_n_device_buf(sizeof(CDataType) *
c_g_m_n_device_result.mDesc.GetElementSpaceSize());
printf("a_g_m_k size: %zu, b_g_k_n size: %zu, b1_g_k_n size: %zu, c_g_m_n size: %zu\n",
a_g_m_k.mDesc.GetElementSpaceSize(),
b_g_k_n_permute.mDesc.GetElementSpaceSize(),
b1_g_k_n.mDesc.GetElementSpaceSize(),
c_g_m_n_device_result.mDesc.GetElementSpaceSize());
// weight permute
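// When the device instance expects pre-shuffled weights (PermuteB == true), the host reorders
// B from its [N, K] column-major storage into [K0, N, K1] blocks with K1 = KPerBlock and
// K0 = K / KPerBlock, presumably mirroring the order in which the kernel reads the packed
// int4 data. Otherwise B is copied through unchanged.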
if constexpr(PermuteB)
{
printf("Permute B\n");
int K1 = KPerBlock;
int K0 = K / KPerBlock;
// int K0, N, K1
for(int bs = 0; bs < batch_count; bs++)
{
for(int j = 0; j < K0; j++)
{
for(int i = 0; i < N; i++)
{
for(int jj = 0; jj < K1; jj++)
{
b_g_k_n_permute(bs * batch_stride_B + j * N * K1 + i * K1 + jj) =
b_g_k_n(bs * batch_stride_B + i * K + (j * K1 + jj));
}
}
}
}
}
else
{
b_g_k_n_permute = b_g_k_n;
}
// vector pk_i4x4 permute
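// Each pk_i4_t byte holds two 4-bit values (high and low nibble). For every group of 8
// consecutive K elements the host unpacks the nibbles and repacks them in the order
// 2,0,6,4,3,1,7,5 (see the writes below), which is presumably the nibble order the
// device-side int4 -> fp16 conversion expects.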
for(int bs = 0; bs < batch_count; bs++)
{
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j += 8)
{
int input[8];
for(int k = 0; k < 4; k++)
{
int i4x2 = b_g_k_n_permute(bs, j + k * 2, i).data;
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// permute 01234567->20643175
{
int hi = input[2];
int lo = input[0];
int i4x2 = (hi << 4) | lo;
b_g_k_n_permute(bs, j + 0, i) = i4x2;
}
{
int hi = input[6];
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b_g_k_n_permute(bs, j + 2, i) = i4x2;
}
{
int hi = input[3];
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b_g_k_n_permute(bs, j + 4, i) = i4x2;
}
{
int hi = input[7];
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b_g_k_n_permute(bs, j + 6, i) = i4x2;
}
}
}
}
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
b_g_k_n_device_buf.ToDevice(b_g_k_n_permute.mData.data());
b1_g_scale_device_buf.ToDevice(b1_g_k_n.mData.data());
DeviceMem workspace;
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceBatchedGemmV2Instance{};
auto invoker = gemm.MakeInvoker();
float ave_time = 0;
auto argument =
gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_g_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
stride_A,
stride_B,
stride_C,
Scale_Stride_BN,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_BScale_Stride,
static_cast<BScaleDataType*>(b1_g_scale_device_buf.GetDeviceBuffer()),
batch_count, // batch count
KBatch, // split K count
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
bool pass = true;
Tensor<float> b_g_k_n_dequant({batch_count, K, N});
if(config.do_verification)
{
float v_b = 0;
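// Host-side reference dequantization: extract the 4-bit value for (k, n) from its pk_i4_t
// byte (even k sits in the high nibble, odd k in the low nibble), shift it from the stored
// [0, 15] range to signed [-8, 7] by subtracting 8, then multiply by the fp16 scale of its
// (k / Scale_Block_K, n / Scale_Block_N) block.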
for(int bs = 0; bs < batch_count; bs++)
{
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
ck::pk_i4_t i4x2 = b_g_k_n(bs, k, n).data;
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2.data >> 0) & 0xf;
else
i4 = (i4x2.data >> 4) & 0xf;
i4 = i4 - 8;
v_b = ck::type_convert<float>(i4);
b_g_k_n_dequant(bs, k, n) =
ck::type_convert<float>(v_b) *
ck::type_convert<float>(b1_g_k_n(bs, k / Scale_Block_K, n / Scale_Block_N));
}
}
}
auto ref_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_g_m_k,
b_g_k_n_dequant,
c_g_m_n_host_result,
PassThrough{},
PassThrough{},
PassThrough{});
ref_invoker.Run(ref_argument);
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0});
hip_check_error(hipDeviceSynchronize());
c_g_m_n_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
pass &= ck::utils::check_err(c_g_m_n_device_result,
c_g_m_n_host_result,
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
}
if(config.time_kernel)
{
ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
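// Bytes moved: the B term is divided by 2 when BDataType is pk_i4_t because two 4-bit
// weights share a single byte.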
std::size_t num_btype =
sizeof(ADataType) * M * K +
sizeof(BDataType) * K * N /
(ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) +
sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
#if 0
// print A matrix
printf("A matrix:\n");
for(int bs = 0; bs < batch_count; bs++)
{
printf("batch %d -> Address: %p\n", bs, static_cast<void*>(&a_g_m_k(bs, 0, 0)));
for(int i = 0; i < M; i++)
{
for(int j = 0; j < K; j++)
{
printf("%.2f,", static_cast<float>(a_g_m_k(bs, i, j)));
}
printf("\n");
}
}
// print B matrix original
printf("B matrix original:\n");
for(int bs = 0; bs < batch_count; bs++)
{
printf("batch %d -> Address: %p\n", bs, static_cast<void*>(&b_g_k_n(bs, 0, 0)));
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
ck::pk_i4_t i4x2 = b_g_k_n(bs, k, n).data;
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2.data >> 0) & 0xf;
else
i4 = (i4x2.data >> 4) & 0xf;
i4 = i4 - 8;
printf("%d,", static_cast<int>(i4));
}
printf("\n");
}
}
// print B matrix
printf("B matrix:\n");
for(int bs = 0; bs < batch_count; bs++)
{
printf("batch %d -> Address: %p\n", bs, static_cast<void*>(&b_g_k_n_dequant(bs, 0, 0)));
for(int i = 0; i < K; i++)
{
for(int j = 0; j < N; j++)
{
printf("%.2f, ", static_cast<float>(b_g_k_n_dequant(bs, i, j)));
}
printf("\n");
}
}
// print B scale matrix
printf("B Scale matrix:\n");
for(int bs = 0; bs < batch_count; bs++)
{
printf("batch %d -> Address: %p\n", bs, static_cast<void*>(&b1_g_k_n(bs, 0, 0)));
for(int i = 0; i < (K + Scale_Block_K - 1) / Scale_Block_K; i++)
{
for(int j = 0; j < (N + Scale_Block_N - 1) / Scale_Block_N; j++)
{
printf("%.2f, ", static_cast<float>(b1_g_k_n(bs, i, j)));
}
printf("\n");
}
}
// print C matrix
printf("C matrix:\n");
for(int bs = 0; bs < batch_count; bs++)
{
printf(
"batch %d -> Address: %p\n", bs, static_cast<void*>(&c_g_m_n_device_result(bs, 0, 0)));
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
printf("%.2f, ", static_cast<float>(c_g_m_n_device_result(bs, i, j)));
}
printf("\n");
}
}
printf("C reference matrix:\n");
for(int bs = 0; bs < batch_count; bs++)
{
printf("batch %d -> Address: %p\n", bs, static_cast<void*>(&c_g_m_n_host_result(bs, 0, 0)));
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
printf("%.2f, ", static_cast<float>(c_g_m_n_host_result(bs, i, j)));
}
printf("\n");
}
}
#endif
return pass;
}
bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
std::mt19937 gen(11939);
std::uniform_int_distribution<int> dis(0, 15);
problem_size.M = 128 * (dis(gen) + 1);
problem_size.N = 128 * (dis(gen) + 1);
problem_size.K = 256 * (dis(gen) + 2);
problem_size.batch_count = 2;
if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc >= 7)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.M = std::stoi(argv[4]);
problem_size.N = std::stoi(argv[5]);
problem_size.K = std::stoi(argv[6]);
if(argc >= 8)
{
problem_size.batch_count = std::stoi(argv[7]);
}
if(argc >= 9)
{
problem_size.KBatch = std::stoi(argv[8]);
}
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0);
}
problem_size.stride_A = problem_size.K;
problem_size.stride_B = problem_size.K;
problem_size.stride_C = problem_size.N;
problem_size.batch_stride_A = problem_size.M * problem_size.K;
problem_size.batch_stride_B = problem_size.K * problem_size.N;
problem_size.batch_stride_C = problem_size.M * problem_size.N;
return run_batched_gemm(problem_size, config);
}
@@ -44,6 +44,48 @@ struct DeviceBatchedGemm : public BaseOperator
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
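// Base-operator interface for batched GEMM with a blockwise-quantized B operand. Compared to
// DeviceBatchedGemm, MakeArgumentPointer additionally takes the B-scale pointer, the B-scale
// leading and per-batch strides, the batch count and a split-K factor; instances also expose
// GetPermuteB() and GetKPerBlock() so callers can pre-shuffle the packed weights to match the
// kernel's expected layout.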
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename BScaleType,
typename CDataType,
index_t ScaleBlockN,
index_t ScaleBlockK,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceBatchedGemmV2BScale : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t StrideScaleB,
ck::index_t BatchStrideA,
ck::index_t BatchStrideB,
ck::index_t BatchStrideC,
ck::index_t BatchStrideScaleB,
const void* p_b_scale,
ck::index_t Batch,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual bool GetPermuteB() = 0;
virtual ck::index_t GetKPerBlock() = 0;
};
template <typename ALayout,
          typename BLayout,
          typename CLayout,
...
@@ -37,7 +37,7 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
+    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
         karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
@@ -70,7 +70,7 @@ __global__ void
     __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
+    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

     GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
         karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
@@ -638,45 +638,45 @@ struct GridwiseGemm_xdl_cshuffle_v3
     struct SplitKBatchOffset
     {
-        __device__ SplitKBatchOffset(Argument& karg)
+        __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
         {
             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
             {
-                a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
+                a_k_split_offset = k_id * karg.KRead / APackedSize;
             }
             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
             {
-                a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
+                a_k_split_offset = k_id * karg.KRead * karg.StrideA;
             }

             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
             {
-                b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
+                b_k_split_offset = k_id * karg.KRead * karg.StrideB;
             }
             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
             {
                 if constexpr(!PermuteB)
                 {
-                    b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
+                    b_k_split_offset = k_id * karg.KRead / BPackedSize;
                 }
                 else
                 {
                     const int k0_offset = karg.KRead * karg.N;
-                    b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
+                    b_k_split_offset = k_id * k0_offset / BPackedSize;
                 }
             }

             // Calculate B scale offset
             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
             {
-                scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB;
+                scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
             }
             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
             {
-                scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
+                scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
             }

-            if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
+            if(k_id < (karg.KBatch - 1))
             {
                 karg.K = karg.KRead;
             }
@@ -687,7 +687,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
             if(karg.IsReduceAdd())
             {
-                c_reduce_offset = blockIdx.z * karg.M * karg.N;
+                c_reduce_offset = k_id * karg.M * karg.N;
             }
             else
             {
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <memory>
#include <vector>
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
void add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmV2BScale<Row,
Col,
Row,
F16,
I4,
F16,
F16,
1,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <typename ADataType,
typename BDataType,
typename BScaleDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
index_t ScaleBlockK>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatchedGemmV2BScale<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
BScaleDataType,
CDataType,
1,
ScaleBlockK,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceBatchedGemmV2BScale<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
BScaleDataType,
CDataType,
1,
ScaleBlockK,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, pk_i4_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
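// A minimal usage sketch of the factory above (hypothetical; the profiler implementation later
// in this diff follows the same pattern):
//
//   using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemmV2BScale<
//       Row, Col, Row, F16, I4, F16, F16, 1, 128, PassThrough, PassThrough, PassThrough>;
//
//   const auto op_ptrs = ck::tensor_operation::device::instance::
//       DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
//
//   for(const auto& op_ptr : op_ptrs)
//   {
//       auto arg_ptr = op_ptr->MakeArgumentPointer(/* pointers, sizes, strides, B-scale
//                                                     pointer/strides, batch, split-K,
//                                                     element-wise ops */);
//       if(op_ptr->IsSupportedArgument(arg_ptr.get()))
//           op_ptr->MakeInvokerPointer()->Run(arg_ptr.get(), StreamConfig{});
//   }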
# ONLY XDL_KERNELS
set(BATCHED_GEMM_B_SCALE_INSTANCES)
list(APPEND BATCHED_GEMM_B_SCALE_INSTANCES
device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
)
set_source_files_properties(device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_batched_gemm_b_scale_instance ${BATCHED_GEMM_B_SCALE_INSTANCES})
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmV2BScale<Row,
Col,
Row,
F16,
I4,
F16,
F16,
1,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Intrawave,
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include <iostream>
#include <typeinfo>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_b_scale.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
namespace ck {
namespace profiler {
template <typename ADataType,
typename BDataType,
typename BScaleDataType,
typename ComputeDataType,
typename AccDataType,
typename CDataType,
index_t ScaleBlockK,
typename ALayout,
typename BLayout,
typename CLayout>
bool profile_batched_gemm_b_scale_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC,
int BatchStrideA,
int BatchStrideB,
int BatchStrideC,
int BatchStrideScaleB,
int BatchSize,
int KBatch,
int n_warmup,
int n_iter,
uint64_t rotating = 0)
{
bool pass = true;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
}
else
{
return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
}
};
ck::index_t Scale_Stride_BN = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
? ((K + ScaleBlockK - 1) / ScaleBlockK)
: N;
Tensor<ADataType> a_g_m_k(
f_host_tensor_descriptor(BatchSize, M, K, StrideA, BatchStrideA, ALayout{}));
Tensor<BDataType> b_g_k_n(
f_host_tensor_descriptor(BatchSize, K, N, StrideB, BatchStrideB, BLayout{}));
Tensor<BDataType> b_g_k_n_permute(
f_host_tensor_descriptor(BatchSize, K, N, StrideB, BatchStrideB, BLayout{}));
Tensor<BScaleDataType> b1_g_k_n(f_host_tensor_descriptor(
BatchSize,
(K + ScaleBlockK - 1) / ScaleBlockK, // K direction group size is ScaleBlockK
N, // N direction group size is 1
Scale_Stride_BN,
BatchStrideScaleB,
BLayout{}));
Tensor<CDataType> c_g_m_n_host_result(
f_host_tensor_descriptor(BatchSize, M, N, StrideC, BatchStrideC, CLayout{}));
Tensor<CDataType> c_g_m_n_device_result(
f_host_tensor_descriptor(BatchSize, M, N, StrideC, BatchStrideC, CLayout{}));
int total_gemm_needed = a_g_m_k.GetElementSpaceSizeInBytes() +
b_g_k_n.GetElementSpaceSizeInBytes() +
b1_g_k_n.GetElementSpaceSizeInBytes();
int rotating_count = std::max(
1,
std::min(n_iter,
static_cast<int>(std::ceil(static_cast<double>(rotating) / total_gemm_needed))));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
std::cout << "b1_g_k_n: " << b1_g_k_n.mDesc << std::endl;
std::cout << "c_g_m_n: " << c_g_m_n_device_result.mDesc << std::endl;
std::cout << "rotating count: " << rotating_count << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
break;
case 2:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
b1_g_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(BScaleDataType) * b1_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b1_device_buf.ToDevice(b1_g_k_n.mData.data());
using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemmV2BScale<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
BScaleDataType,
CDataType,
1,
ScaleBlockK,
AElementOp,
BElementOp,
CElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM
if(do_verification)
{
Tensor<float> b_g_k_n_dequant({BatchSize, K, N}); // one dequantized copy per batch, matching the (bs, k, n) accesses below
float v_b = 0;
for(int bs = 0; bs < BatchSize; bs++)
{
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
ck::pk_i4_t i4x2 = b_g_k_n(bs, k, n).data;
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2.data >> 0) & 0xf;
else
i4 = (i4x2.data >> 4) & 0xf;
i4 = i4 - 8;
v_b = ck::type_convert<float>(i4);
b_g_k_n_dequant(bs, k, n) =
ck::type_convert<float>(v_b) *
ck::type_convert<float>(b1_g_k_n(bs, k / ScaleBlockK, n));
}
}
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp,
ComputeDataType>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_g_m_k,
b_g_k_n_dequant,
c_g_m_n_host_result,
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
}
std::string best_op_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
// profile device GEMM instances
for(auto& op_ptr : op_ptrs)
{
const int KPerBlock = op_ptr->GetKPerBlock();
if(op_ptr->GetPermuteB())
{
int K1 = KPerBlock;
int K0 = K / KPerBlock;
// int K0, N, K1
for(int bs = 0; bs < BatchSize; bs++)
{
for(int j = 0; j < K0; j++)
{
for(int i = 0; i < N; i++)
{
for(int jj = 0; jj < K1; jj++)
{
b_g_k_n_permute(bs * BatchStrideB + j * N * K1 + i * K1 + jj) =
b_g_k_n(bs * BatchStrideB + i * K + (j * K1 + jj));
}
}
}
}
if(is_same_v<BDataType, pk_i4_t> && is_same_v<ADataType, half_t>)
{
// vector pk_i4x4 permute
for(int bs = 0; bs < BatchSize; bs++)
{
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j += 8)
{
int input[8];
for(int k = 0; k < 4; k++)
{
int i4x2 = b_g_k_n_permute(bs, j + k * 2, i).data;
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// permute 01234567->20643175
{
int hi = input[2];
int lo = input[0];
int i4x2 = (hi << 4) | lo;
b_g_k_n_permute(bs, j + 0, i) = i4x2;
}
{
int hi = input[6];
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b_g_k_n_permute(bs, j + 2, i) = i4x2;
}
{
int hi = input[3];
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b_g_k_n_permute(bs, j + 4, i) = i4x2;
}
{
int hi = input[7];
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b_g_k_n_permute(bs, j + 6, i) = i4x2;
}
}
}
}
}
}
else
{
b_g_k_n_permute = b_g_k_n;
}
b_device_buf.ToDevice(b_g_k_n_permute.mData.data());
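// When KBatch <= 0 the profiler sweeps the split-K factors listed below and reports the best
// performing one; a positive KBatch pins the sweep to that single value.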
std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};
if(KBatch > 0)
{
kbatch_list = {KBatch};
}
for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
Scale_Stride_BN,
BatchStrideA,
BatchStrideB,
BatchStrideC,
BatchStrideScaleB,
static_cast<BScaleDataType*>(b1_device_buf.GetDeviceBuffer()),
BatchSize, // Batch count
kbatch_curr, // Split K count
a_element_op,
b_element_op,
c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
// invoker_ptr->Run(argument_ptr.get(),
// StreamConfig{nullptr, false, 0, n_warmup, n_iter});
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 0});
if(do_verification)
{
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
#if defined CK_ENABLE_FP8
// set softer tolerances for fp8
if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
is_same_v<CDataType, f8_t>)
{
std::string msg = "Error: Incorrect results!";
double rtol = 1e-1;
double atol = 1e-1;
pass =
pass & ck::utils::check_err(
c_g_m_n_device_result, c_g_m_n_host_result, msg, rtol, atol);
}
else
{
#endif
pass =
pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
#if defined CK_ENABLE_FP8
}
#endif
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_g_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
<< std::endl;
}
}
std::string op_name = op_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
time_kernel,
0,
n_warmup,
n_iter,
rotating_count > 1,
rotating_count});
std::size_t flop = std::size_t(2) * M * N * K * BatchSize;
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
return 2;
else
return 1;
}();
std::size_t num_btype = sizeof(ADataType) * M * K +
sizeof(BDataType) * K * N / BPackedSize +
sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<< kbatch_curr << std::endl;
if(tflops > best_tflops && ave_time > 1e-10)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
}
if constexpr(is_same<CDataType, float>::value)
{
std::cout << "Best Perf for datatype = f32";
}
else if constexpr(is_same<CDataType, half_t>::value)
{
std::cout << "Best Perf for datatype = f16";
}
else if constexpr(is_same<CDataType, bhalf_t>::value)
{
std::cout << "Best Perf for datatype = bf16";
}
else if constexpr(is_same<CDataType, int8_t>::value)
{
std::cout << "Best Perf for datatype = int8";
}
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " ALayout = RowMajor";
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " ALayout = ColumnMajor";
}
if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " BLayout = RowMajor";
}
else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " BLayout = ColumnMajor";
}
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck
@@ -59,6 +59,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
     list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp)
     list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp)
     list(APPEND PROFILER_SOURCES profile_gemm_b_scale.cpp)
+    list(APPEND PROFILER_SOURCES profile_batched_gemm_b_scale.cpp)
     list(APPEND PROFILER_SOURCES profile_gemm_universal_batched.cpp)
     list(APPEND PROFILER_SOURCES profile_gemm_universal_reduce.cpp)
     list(APPEND PROFILER_SOURCES profile_gemm_universal_streamk.cpp)
@@ -143,6 +144,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
     target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
     target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance)
     target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_b_scale_instance)
+    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_b_scale_instance)
     target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_batched_instance)
     target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_reduce_instance)
     target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_streamk_instance)
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include "profiler/profile_batched_gemm_b_scale_impl.hpp"
#include "profiler_operation_registry.hpp"
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
};
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
F8_F16_F16, // 4
F16_F8_F16, // 5
F16_F16_F16_F8, // 6
F8_F8_BF16, // 7
F16_I4_F16, // 8
};
enum struct BScaleBlockTile
{
K_64, // 0
K_128, // 1
};
#define OP_NAME "batched_gemm_b_scale"
#define OP_DESC "Int4-dequant batched GEMM"
int profile_batched_gemm_b_scale(int argc, char* argv[])
{
if(argc != 17 && argc != 20)
{
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8; 8: f16@i4)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg4: B scale block tile (0: 64, 1: 128):\n");
printf("arg5: verification (0: no; 1: yes)\n");
printf("arg6: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg7: print tensor value (0: no; 1: yes)\n");
printf("arg8: time kernel (0=no, 1=yes)\n");
printf("arg9 to 15: M, N, K, StrideA, StrideB, StrideC, BatachCount\n");
printf("arg16: split k into mulitiple batch\n");
printf("optional:\n");
printf("arg17: number of warm-up cycles (default 1)\n");
printf("arg18: number of iterations (default 10)\n");
printf("arg19: memory for rotating buffer (default 0, size in MB)\n");
exit(1);
}
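// Example invocation (hypothetical problem sizes), following the argument list printed above:
//   ckProfiler batched_gemm_b_scale 8 1 1 1 1 0 1 3840 4096 4096 -1 -1 -1 2 1
// i.e. fp16 x int4 data, MK_NK_MN layout, 128-wide B-scale blocks, verify on, integer init,
// no tensor dump, time the kernel, M=3840 N=4096 K=4096, default strides, batch count 2,
// split-K 1.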
printf("Start profiling\n");
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const auto B_scale_block = static_cast<BScaleBlockTile>(std::stoi(argv[4]));
const bool do_verification = std::stoi(argv[5]);
const int init_method = std::stoi(argv[6]);
const bool do_log = std::stoi(argv[7]);
const bool time_kernel = std::stoi(argv[8]);
const int M = std::stoi(argv[9]);
const int N = std::stoi(argv[10]);
const int K = std::stoi(argv[11]);
const int StrideA = std::stoi(argv[12]);
const int StrideB = std::stoi(argv[13]);
const int StrideC = std::stoi(argv[14]);
const int BatchStrideA = M * K; // A is an M x K matrix per batch
const int BatchStrideB = N * K;
const int BatchStrideC = M * N;
// BScaleBlockTile is an enum tag (0 -> K_64, 1 -> K_128); map it to the actual K block size
// before deriving the per-batch stride of the B-scale tensor.
const int ScaleBlockKSize   = (B_scale_block == BScaleBlockTile::K_64) ? 64 : 128;
const int BatchStrideScaleB = (K + ScaleBlockKSize - 1) / ScaleBlockKSize * N;
const int BatchSize = std::stoi(argv[15]);
const int KBatch = std::stoi(argv[16]);
printf("M:%d, N:%d, K:%d, StrideA:%d, StrideB:%d, StrideC:%d, BatchStrideA:%d, "
"BatchStrideB:%d, BatchStrideC:%d, BatchStrideScaleB:%d, BatchSize:%d, KBatch:%d,\n",
M,
N,
K,
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
BatchStrideScaleB,
BatchSize,
KBatch);
int n_warmup = 1;
int n_iter = 10;
uint64_t rotating = 0;
if(argc == 20)
{
n_warmup = std::stoi(argv[17]);
n_iter = std::stoi(argv[18]);
rotating = std::stoull(argv[19]) * 1024 * 1024;
printf("n_warmup:%d, n_iter:%d, rotating:%lu\n", n_warmup, n_iter, rotating);
}
using F32 = float;
using F16 = ck::half_t;
using I4 = ck::pk_i4_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
auto profile = [&](auto a_type,
auto b_type,
auto b_scale_type,
auto comp_type,
auto acc_type,
auto c_type,
auto scale_block_k,
auto a_layout,
auto b_layout,
auto c_layout) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using BScaleDataType = decltype(b_scale_type);
using ComputeDataType = decltype(comp_type);
using AccDataType = decltype(acc_type);
using CDataType = decltype(c_type);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using CLayout = decltype(c_layout);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
bool pass = ck::profiler::profile_batched_gemm_b_scale_impl<ADataType,
BDataType,
BScaleDataType,
ComputeDataType,
AccDataType,
CDataType,
scale_block_k,
ALayout,
BLayout,
CLayout>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideC < 0) ? DefaultStrideC : StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
BatchStrideScaleB,
BatchSize,
KBatch,
n_warmup,
n_iter,
rotating);
return pass ? 0 : 1;
};
if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN &&
B_scale_block == BScaleBlockTile::K_128)
{
printf("F16_I4_F16 MK_NK_MN K_128\n");
return profile(
F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_batched_gemm_b_scale);