Commit 03c2ba3a authored by aska-0096

bug fix + performance opt + clangformat

parent 1a324dfb
......@@ -170,7 +170,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
//BScale Thread Copy
// BScale Thread Copy
typename BScaleGridBuffer,
typename BScaleGridDesc,
typename BScaleThreadDesc,
......
......@@ -719,12 +719,12 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
const BScaleGridDesc& b_scale_grid_desc,
//BScaleThreadCopy
// BScaleThreadCopy
const BScaleThreadDesc& b_scale_thread_desc,
BScaleThreadTransfer& b_scale_thread_copy,
const BScaleGridBuffer& b_scale_grid_buf,
const BScaleThreadTransferStep& b_scale_thread_copy_step,
//num loop
// num loop
index_t num_loop,
index_t num_loop_per_scale) const
{
......@@ -864,14 +864,14 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
}
});
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t)
// {
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]);
// });
});
});
__builtin_amdgcn_sched_barrier(0);
......@@ -986,7 +986,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
// c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]);
// });
});
});
__builtin_amdgcn_sched_barrier(0);
......@@ -1084,7 +1083,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
// c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]);
// });
});
});
__builtin_amdgcn_sched_barrier(0);
......
......@@ -295,25 +295,24 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
//BScaleThreadCopy
// BScaleThreadCopy
const BScaleGridDesc& b_scale_grid_desc,
const BScaleThreadDesc& b_scale_thread_desc,
BScaleThreadTransfer& b_scale_thread_copy,
const BScaleGridBuffer& b_scale_grid_buf,
const BScaleThreadTransferStep& b_scale_thread_copy_step,
//num loop
// num loop
index_t num_loop,
index_t num_loop_per_scale) const
{
__builtin_amdgcn_sched_barrier(0);
ignore = num_loop_per_scale;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize());
//B scale buffer
// B scale buffer
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_scale_thread_desc.GetElementSpaceSize());
......@@ -328,14 +327,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, I0, I0),
make_tuple(n0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<0>{}));
});
if(num_loop_per_scale == 1)
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<2>{}));
}
else
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
}
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
......@@ -350,7 +358,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
// Initialize C
c_thread_buf.Clear();
auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>(); // need actually?
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
AccDataType,
1,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_per_scale;
// Local prefetch 1
block_sync_lds();
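Note on the accumulator change above: rather than cloning the whole C thread buffer (the removed remove_cvref_t<decltype(c_thread_buf)>() line), the pipeline now keeps a single-xdlops-sized register buffer, clears it per (m0, n0) pair, and folds it into c_thread_buf after multiplying by the B scale for that n0, as the v4 pipeline's accumulation code further down shows. A minimal standalone sketch of that accumulate-then-scale pattern, with plain arrays standing in for the CK static buffers (all names here are illustrative, not the CK API):

    #include <array>
    #include <cstddef>

    // One xdlops tile yields RegSizePerXdlops partial sums; the small per-scale
    // buffer is reused for every (m0, n0) pair instead of duplicating c_thread_buf.
    template <std::size_t MRepeat, std::size_t NRepeat, std::size_t RegSizePerXdlops>
    void accumulate_with_b_scale(
        std::array<float, MRepeat * NRepeat * RegSizePerXdlops>& c_thread_buf,
        const std::array<float, NRepeat>& b_scale, // one scale per n0
        const std::array<float, MRepeat * NRepeat * RegSizePerXdlops>& mfma_partials)
    {
        std::array<float, RegSizePerXdlops> c_per_scale{}; // role of c_thread_buf_per_scale
        for(std::size_t m0 = 0; m0 < MRepeat; ++m0)
            for(std::size_t n0 = 0; n0 < NRepeat; ++n0)
            {
                c_per_scale.fill(0.0f); // c_thread_buf_per_scale.Clear()
                const std::size_t base = (m0 * NRepeat + n0) * RegSizePerXdlops;
                for(std::size_t t = 0; t < RegSizePerXdlops; ++t)
                    c_per_scale[t] += mfma_partials[base + t]; // stands in for xdlops_gemm.Run
                for(std::size_t t = 0; t < RegSizePerXdlops; ++t)
                    c_thread_buf[base + t] += c_per_scale[t] * b_scale[n0]; // scale once per tile
            }
    }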
......@@ -415,8 +429,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
......@@ -455,15 +468,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, I0, I0),
make_tuple(n0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
if((i + 2) % num_loop_per_scale == 0)
{
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{}));
}
else
{
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
}
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
......
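Note on the scale-window movement in this pipeline: the prologue and hot loop now pick between two steps, At(Number<1>{}) (rewind N, stay in the current ScaleBlockK group) and At(Number<2>{}) (rewind N, advance to the next group), instead of always advancing. A tiny standalone sketch of the selection, assuming num_loop_per_scale = ScaleBlockK / KPerBlock and that the +2 offset reflects the pipeline prefetching scales ahead of the iterations that consume them (the helper name is illustrative):

    #include <cstdio>

    // Mirrors "(i + 2) % num_loop_per_scale == 0" from the hot loop; with
    // num_loop_per_scale == 1 every iteration crosses a ScaleBlockK boundary,
    // which matches the special case in the prologue.
    bool advance_scale_window(int i, int num_loop_per_scale)
    {
        return (i + 2) % num_loop_per_scale == 0;
    }

    int main()
    {
        const int num_loop_per_scale = 4; // e.g. ScaleBlockK = 128, KPerBlock = 32
        for(int i = 0; i < 8; ++i)
            std::printf("iter %d: %s\n",
                        i,
                        advance_scale_window(i, num_loop_per_scale)
                            ? "move scale window to next ScaleBlockK group"
                            : "keep current ScaleBlockK group");
    }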
......@@ -268,13 +268,13 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
//BScaleThreadCopy
// BScaleThreadCopy
const BScaleGridDesc& b_scale_grid_desc,
const BScaleThreadDesc& b_scale_thread_desc,
BScaleThreadTransfer& b_scale_thread_copy,
const BScaleGridBuffer& b_scale_grid_buf,
const BScaleThreadTransferStep& b_scale_thread_copy_step,
//num loop
// num loop
index_t num_loop,
index_t num_loop_per_scale) const
{
......@@ -284,7 +284,7 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize());
//B scale buffer
// B scale buffer
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_scale_thread_desc.GetElementSpaceSize());
......@@ -409,11 +409,11 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
make_tuple(n0, I0),
b_scale_thread_bufs(lds_read_reg_buf));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<0>{}));
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
a_blockwise_copy.RunWrite(
a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
......@@ -426,7 +426,6 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
......@@ -452,8 +451,7 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
......@@ -462,7 +460,8 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(b_scale_thread_bufs(mfma_reg_buf)[n0]);
type_convert<AccDataType>(
b_scale_thread_bufs(mfma_reg_buf)[n0]);
});
});
});
......@@ -521,7 +520,6 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
......@@ -640,7 +638,6 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
};
auto CompFunc = [&](auto mfma_reg_buf) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
......
......@@ -218,7 +218,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
};
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) > 128 * 128 * 64 * 2)
? 1
: 2
: 2;
if(has_main_k_block_loop)
{
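Note on the occupancy selection above: the __launch_bounds__ hint is no longer a fixed 1/2 split on the scheduler alone; with Intrawave scheduling it drops to 1 only for the v3 pipeline when MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) exceeds 128 * 128 * 64 * 2, i.e. only for large tiles. A compact restatement as a standalone constexpr helper (the helper and enum names are illustrative, and the product is used here, as in the source, as a rough tile-size proxy):

    #include <cstddef>

    enum class Scheduler { Intrawave, Interwave };
    enum class PipelineVersion { v1, v2, v3, v4 };

    // Large v3 intrawave tiles get minimum occupancy 1; everything else gets 2.
    constexpr int minimum_occupancy(Scheduler sched,
                                    PipelineVersion ver,
                                    std::size_t m_per_block,
                                    std::size_t n_per_block,
                                    std::size_t k_per_block,
                                    std::size_t a_elem_bytes)
    {
        const bool big_tile =
            m_per_block * n_per_block * k_per_block * a_elem_bytes > 128ull * 128 * 64 * 2;
        return (sched == Scheduler::Intrawave && ver == PipelineVersion::v3 && big_tile) ? 1 : 2;
    }

    static_assert(minimum_occupancy(Scheduler::Intrawave, PipelineVersion::v3, 256, 256, 64, 2) == 1);
    static_assert(minimum_occupancy(Scheduler::Interwave, PipelineVersion::v3, 256, 256, 64, 2) == 2);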
......@@ -664,7 +669,20 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, p_b_scale, KBatch, a_element_op, b_element_op, c_element_op};
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
p_b_scale,
KBatch,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
......
......@@ -13,9 +13,9 @@ namespace ck {
__host__ __device__ inline half4_t pki4_to_half4(int q)
{
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
constexpr int LO = 0x000f000f;
constexpr int HI = 0x00f000f0;
constexpr int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
// int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
......@@ -23,9 +23,9 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
int hi = amd_assembly_and_or_b32(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0xE408E408; //-8
const int MUL = 0x2c002c00; // 1/16
const int ADD = 0xd480d480; //-79
constexpr int SUB = 0xE408E408; //-8
constexpr int MUL = 0x2c002c00; // 1/16
constexpr int ADD = 0xd480d480; //-79
vector_type<half_t, 4> res;
......@@ -34,7 +34,15 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));
#if 0
asm volatile("v_and_or_b32 %0, %4, %5, %7 \n \
v_and_or_b32 %1, %4, %6, %7 \n \
v_pk_add_f16 %2, %0, %8 \n \
v_pk_fma_f16 %3, %1, %9, %10 \
"
: "=v"(lo), "=v"(hi), "=v"(res.template AsType<half2_t>()(Number<0>{})), "=v"(res.template AsType<half2_t>()(Number<1>{}))
: "v"(q), "v"(LO), "v"(HI), "s"(EX), "s"(SUB), "v"(MUL), "s"(ADD), "0"(lo), "1"(hi));
#endif
return res.template AsType<half4_t>()[Number<0>{}];
}
......
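For reference, the magic numbers in pki4_to_half4 follow the usual packed-int4 to fp16 conversion trick: 0x6400 is 1024.0 in fp16, so OR-ing a low nibble v into it yields 1024 + v, and the packed add with 0xE408 (-1032.0 in fp16) produces v - 8, fusing the symmetric zero point. The high nibble sits four bits up, so its lane holds 1024 + 16*v, and the packed FMA with 0x2C00 (1/16) and 0xD480 (-72.0, i.e. -(64 + 8)) folds both the shift and the zero point into one instruction. A host-side scalar model of the same arithmetic (plain floats, not the device intrinsics):

    #include <cassert>

    // Scalar model of one fp16 lane of pki4_to_half4:
    //   low nibble  : (1024 + v) + (-1032)           -> v - 8
    //   high nibble : (1024 + 16*v) * (1/16) + (-72) -> v - 8
    int main()
    {
        for(int v = 0; v < 16; ++v)
        {
            const float lo = (1024.0f + v) + (-1032.0f);                  // v_pk_add_f16 with 0xE408
            const float hi = (1024.0f + 16.0f * v) * 0.0625f + (-72.0f);  // v_pk_fma_f16 with 0x2C00, 0xD480
            assert(lo == static_cast<float>(v - 8));
            assert(hi == static_cast<float>(v - 8));
        }
    }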
......@@ -48,12 +48,7 @@ __global__ void
// karg);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg.p_b_scale_grid,
p_shared,
karg);
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, karg.p_b_scale_grid, p_shared, karg);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
......@@ -1303,11 +1298,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
// B Scale grid and buffer
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK),
ScaleBlockK),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK),
1,
0));
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
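With the descriptor above reduced from three dimensions to two, the B-scale grid is addressed as a row-major [ceil(N/ScaleBlockN), ceil(K/ScaleBlockK)] array holding one scale per (ScaleBlockN x ScaleBlockK) tile of B. A small helper showing the element lookup implied by the strides (ceil(K/ScaleBlockK), 1) above (the function name and flat-index view are illustrative):

    #include <cstdint>

    // Flat offset of the scale covering element (n, k) of B, for a descriptor with
    //   lengths: (ceil(N/ScaleBlockN), ceil(K/ScaleBlockK)), strides: (ceil(K/ScaleBlockK), 1)
    inline std::int64_t b_scale_offset(std::int64_t n, std::int64_t k,
                                       std::int64_t K, std::int64_t ScaleBlockN, std::int64_t ScaleBlockK)
    {
        const std::int64_t k_groups = (K + ScaleBlockK - 1) / ScaleBlockK;
        return (n / ScaleBlockN) * k_groups + (k / ScaleBlockK);
    }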
......@@ -1438,12 +1430,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
KPerBlock);
// b scale
static_assert(KPerBlock<=ScaleBlockK);
static_assert(KPerBlock <= ScaleBlockK);
const index_t ScaleSliceSizeN = NXdlPerWave;
const index_t ScaleSliceSizeK = 1;
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}, Number<1>{}));
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
......@@ -1455,22 +1447,24 @@ struct GridwiseGemm_xdl_cshuffle_v3
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<1, ScaleSliceSizeK, 1>,
Sequence<0, 1, 2>,
Sequence<1, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
1,
1,
false>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0, 0));
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0));
constexpr auto b_scale_thread_slice_copy_step =
make_tuple(make_multi_index(NWaves * NPerXdl, 0, 0),
make_multi_index(-NPerBlock, KPerBlock/ScaleBlockK, KPerBlock%ScaleBlockK));
make_tuple(make_multi_index(NWaves * NPerXdl, 0),
make_multi_index(-NPerBlock, 0),
make_multi_index(-NPerBlock, 1));
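The step tuple above has three entries in (scale-N, scale-K) coordinates: At(0) = (+NWaves * NPerXdl, 0) hops to the next NXdlPerWave repeat inside the block, At(1) = (-NPerBlock, 0) rewinds N without leaving the current ScaleBlockK group, and At(2) = (-NPerBlock, +1) rewinds N and advances to the next group. A standalone sketch of the cumulative offset over one K iteration, assuming ScaleBlockN == 1 (one scale per B column, as the profiler comments indicate) and illustrative tile numbers:

    #include <cassert>

    int main()
    {
        constexpr int NPerBlock = 128, NRepeat = 4, NWaves = 2, NPerXdl = 16;
        int n = 0, k_group = 0;
        for(int n0 = 0; n0 < NRepeat; ++n0)
            n += NWaves * NPerXdl;       // step At(0) after each n0 copy
        // end of the n0 loop, one of:
        //   n += -NPerBlock;            // step At(1): same ScaleBlockK group
        n += -NPerBlock; k_group += 1;   // step At(2): next ScaleBlockK group
        assert(n == 0 && k_group == 1);  // back at column 0, one K group further on
    }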
const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
......@@ -1867,7 +1861,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
a_block_space_size_aligned * sizeof(ADataType)/APackedSize),
a_block_space_size_aligned * sizeof(ADataType) / APackedSize),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
......@@ -1875,7 +1869,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
a_block_space_size_aligned * sizeof(ADataType)/APackedSize),
a_block_space_size_aligned * sizeof(ADataType) / APackedSize),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
......@@ -1924,7 +1918,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
......
......@@ -40,8 +40,8 @@ template <typename ADataType,
typename BLayout,
typename CLayout,
index_t ScaleBlockK>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmV2BScale<ALayout,
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmV2BScale<
ALayout,
BLayout,
CLayout,
ADataType,
......
......@@ -8,13 +8,22 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2BScale<Row, Col, Row, F16, I4, F16, F16, 1, 128, PassThrough, PassThrough, PassThrough>>>&
instances)
std::vector<std::unique_ptr<DeviceGemmV2BScale<Row,
Col,
Row,
F16,
I4,
F16,
F16,
1,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
//device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Interwave, GemmDefault>{});
// device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Interwave, GemmDefault>{});
device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Intrawave, GemmDefault>{});
}
......
......@@ -72,7 +72,8 @@ bool profile_gemm_b_scale_impl(int do_verification,
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BScaleDataType> b1_k_n(f_host_tensor_descriptor((K + ScaleBlockK - 1) / ScaleBlockK, // K direction group size is ScaleBlockK
Tensor<BScaleDataType> b1_k_n(f_host_tensor_descriptor(
(K + ScaleBlockK - 1) / ScaleBlockK, // K direction group size is ScaleBlockK
N, // N direction group size is 1
Scale_Stride_BN,
BLayout{}));
......@@ -167,8 +168,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
i4 = i4 - 8;
v_b = ck::type_convert<float>(i4);
b_k_n_dequant(k, n) =
ck::type_convert<float>(v_b) *
b_k_n_dequant(k, n) = ck::type_convert<float>(v_b) *
ck::type_convert<float>(b1_k_n(k / ScaleBlockK, n));
}
}
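The verification path above dequantizes the packed int4 B on the host: each 4-bit value is recentred with -8 and multiplied by the scale of its group, where a group spans ScaleBlockK rows of K and a single column of N. A condensed standalone version of that reference step (plain vectors instead of the profiler's Tensor objects; names are illustrative):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // b_q holds unpacked 4-bit values in [0, 15], row-major K x N.
    // b_scale holds one scale per (ScaleBlockK x 1) group, row-major ceil(K/ScaleBlockK) x N.
    std::vector<float> dequantize_b(const std::vector<std::int8_t>& b_q,
                                    const std::vector<float>& b_scale,
                                    int K, int N, int ScaleBlockK)
    {
        std::vector<float> b(static_cast<std::size_t>(K) * N);
        for(int k = 0; k < K; ++k)
            for(int n = 0; n < N; ++n)
            {
                const int i4 = static_cast<int>(b_q[static_cast<std::size_t>(k) * N + n]) - 8;
                b[static_cast<std::size_t>(k) * N + n] =
                    static_cast<float>(i4) *
                    b_scale[static_cast<std::size_t>(k / ScaleBlockK) * N + n];
            }
        return b;
    }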
......@@ -291,8 +291,8 @@ bool profile_gemm_b_scale_impl(int do_verification,
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
......
......@@ -82,7 +82,14 @@ int profile_gemm_b_scale(int argc, char* argv[])
const int StrideB = std::stoi(argv[13]);
const int StrideC = std::stoi(argv[14]);
const int KBatch = std::stoi(argv[15]);
printf("M:%d, N:%d, K:%d, StrideA:%d, StrideB:%d, StrideC:%d, KBatch:%d\n", M, N, K, StrideA, StrideB, StrideC, KBatch);
printf("M:%d, N:%d, K:%d, StrideA:%d, StrideB:%d, StrideC:%d, KBatch:%d\n",
M,
N,
K,
StrideA,
StrideB,
StrideC,
KBatch);
int n_warmup = 1;
int n_iter = 10;
......@@ -156,14 +163,18 @@ int profile_gemm_b_scale(int argc, char* argv[])
return pass ? 0 : 1;
};
// if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN && B_scale_block == BScaleBlockTile::K_64)
// if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN &&
// B_scale_block == BScaleBlockTile::K_64)
// {
// return profile(F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<64>{}, Row{}, Col{}, Row{});
// return profile(F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<64>{}, Row{}, Col{},
// Row{});
// }
if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN && B_scale_block == BScaleBlockTile::K_128)
if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN &&
B_scale_block == BScaleBlockTile::K_128)
{
printf("F16_I4_F16 MK_NK_MN K_128\n");
return profile(F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{});
return profile(
F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{});
}
else
{
......