Unverified commit a11cf2c6, authored by arai713, committed by GitHub

Merge branch 'develop' into codegen_hiprtc

parents a72e9efa 64d5c4d6
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -7,9 +7,9 @@
 namespace ck_tile {
-// Y = X * XScale, QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
+// Y = X * SmoothScale, QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
 template <typename XDataType_,
-          typename XScaleDataType_,
+          typename SmoothScaleDataType_,
           typename ComputeDataType_,
           typename YScaleDataType_,
          typename QYDataType_,
@@ -18,12 +18,12 @@ template <typename XDataType_,
          bool kTwoPass_>
 struct SmoothquantPipelineProblem
 {
    using XDataType = remove_cvref_t<XDataType_>;
-   using XScaleDataType = remove_cvref_t<XScaleDataType_>;
+   using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
    using YScaleDataType = remove_cvref_t<YScaleDataType_>;
    using QYDataType = remove_cvref_t<QYDataType_>;
    using BlockShape = remove_cvref_t<BlockShape_>;
    static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
    static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -16,11 +16,11 @@ struct SmoothquantPipelineTwoPass
    using Problem = ck_tile::remove_cvref_t<Problem_>;
    using Policy = ck_tile::remove_cvref_t<Policy_>;
    using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
-   using XScaleDataType = ck_tile::remove_cvref_t<typename Problem::XScaleDataType>;
+   using SmoothScaleDataType = ck_tile::remove_cvref_t<typename Problem::SmoothScaleDataType>;
    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
    using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
    using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
    static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
    static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
@@ -39,9 +39,12 @@ struct SmoothquantPipelineTwoPass
        return Policy::template GetSmemSize<Problem>();
    }
-   template <typename XWindow, typename XScaleWindow, typename QYWindow, typename YScaleWindow>
+   template <typename XWindow,
+             typename SmoothScaleWindow,
+             typename QYWindow,
+             typename YScaleWindow>
    CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
-                                  const XScaleWindow& xscale_window_,
+                                  const SmoothScaleWindow& smscale_window_,
                                   YScaleWindow& yscale_window,
                                   QYWindow& qy_window,
                                   ck_tile::index_t row_size,
@@ -49,8 +52,8 @@ struct SmoothquantPipelineTwoPass
    {
        auto x_window =
            make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
-       auto xscale_window = make_tile_window(
-           xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>());
+       auto smscale_window = make_tile_window(
+           smscale_window_, Policy::template MakeSmoothScaleBlockTileDistribution<Problem>());
        static constexpr index_t Block_N = Problem::BlockShape::Block_N;
        index_t num_n_tile_iteration =
@@ -76,14 +79,14 @@ struct SmoothquantPipelineTwoPass
        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
        {
            const auto x = load_tile(x_window);
-           const auto xscale = load_tile(xscale_window);
+           const auto smscale = load_tile(smscale_window);
            const auto y = tile_elementwise_in(
                [&](const auto& a, const auto& b) {
                    return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
                },
                x,
-               xscale);
+               smscale);
            constexpr auto x_size_per_row =
                x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
@@ -94,7 +97,7 @@ struct SmoothquantPipelineTwoPass
            block_reduce2d(y, absmax, reduce_absmax_func);
            move_tile_window(x_window, {0, Block_N});
-           move_tile_window(xscale_window, {Block_N});
+           move_tile_window(smscale_window, {Block_N});
        }
        // compute absmax, cross-lane->cross-warp
@@ -114,31 +117,31 @@ struct SmoothquantPipelineTwoPass
            row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
        move_tile_window(x_window, {0, -Block_N});
-       move_tile_window(xscale_window, {-Block_N});
+       move_tile_window(smscale_window, {-Block_N});
        move_tile_window(qy_window, {0, stride_to_right_most_window});
        // recompute y and quantize y to qy
        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
        {
            const auto x = load_tile(x_window);
-           const auto xscale = load_tile(xscale_window);
+           const auto smscale = load_tile(smscale_window);
            const auto y = tile_elementwise_in(
                [&](const auto& a, const auto& b) {
                    return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
                },
                x,
-               xscale);
+               smscale);
            auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
            sweep_tile(qy, [&](auto idx) {
                constexpr auto i_idx = make_tuple(idx[number<0>{}]);
                auto qy_ = y[idx] / yscale[i_idx];
-               qy(idx) = saturates<QYDataType>{}(qy_);
+               qy(idx) = type_convert<QYDataType>(saturates<QYDataType>{}(qy_));
            });
            store_tile(qy_window, qy);
            move_tile_window(x_window, {0, -Block_N});
-           move_tile_window(xscale_window, {0, -Block_N});
+           move_tile_window(smscale_window, {0, -Block_N});
            move_tile_window(qy_window, {0, -Block_N});
        }
    }
...
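(For orientation: the formula in the renamed comment, Y = X * SmoothScale with QY = SaturateCast(Y / YScale), is what the two passes above implement. The following host-side scalar sketch is illustrative only, not repository code; it assumes QYDataType is int8 and that YScale is the rowwise absmax divided by 127, and it abstracts away the tile windows, block_reduce2d and saturates<> machinery of the real pipeline.)

// Hedged scalar sketch of the two-pass smoothquant performed by the pipeline above.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

void smoothquant_reference(const std::vector<float>& x,            // M x N, row-major
                           const std::vector<float>& smooth_scale, // N, per-column scale
                           std::vector<float>& yscale,             // M, per-row scale (output)
                           std::vector<int8_t>& qy,                // M x N, quantized output
                           int M,
                           int N)
{
    yscale.assign(M, 0.f);
    qy.assign(static_cast<std::size_t>(M) * N, 0);
    for(int m = 0; m < M; ++m)
    {
        // pass 1: y = x * smooth_scale, rowwise absmax
        float absmax = 0.f;
        for(int n = 0; n < N; ++n)
            absmax = std::max(absmax, std::fabs(x[m * N + n] * smooth_scale[n]));
        yscale[m] = absmax / 127.f; // assumed int8 dynamic-quant scale

        // pass 2: recompute y and saturate-cast y / yscale into int8
        const float s = yscale[m] > 0.f ? yscale[m] : 1.f; // avoid 0/0 on all-zero rows
        for(int n = 0; n < N; ++n)
        {
            const float q = (x[m * N + n] * smooth_scale[n]) / s;
            qy[m * N + n] = static_cast<int8_t>(std::clamp(q, -128.f, 127.f));
        }
    }
}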
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -73,39 +73,9 @@ struct ReferencefpAintBGemm : public device::BaseOperator
    ScaleDataType v_scale;
    ADataType v_converted_b;
-   // use PassThrough instead of ConvertBF16RTN for reference calculation
-   if constexpr(is_same_v<AElementwiseOperation,
-                ck::tensor_operation::element_wise::ConvertBF16RTN>)
-   {
-       ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
-   }
-   else
-   {
-       arg.a_element_op_(v_a, arg.a_m_k_(m, k));
-   }
-   // same for B matrix
-   if constexpr(is_same_v<BElementwiseOperation,
-                ck::tensor_operation::element_wise::ConvertBF16RTN>)
-   {
-       ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
-   }
-   else
-   {
-       arg.b_element_op_(v_b, arg.b_k_n_(k, n));
-   }
-   // same for scale matrix
-   if constexpr(is_same_v<BElementwiseOperation,
-                ck::tensor_operation::element_wise::ConvertBF16RTN>)
-   {
-       ck::tensor_operation::element_wise::PassThrough{}(v_scale,
-                                                         arg.scale_k_n_(k, n));
-   }
-   else
-   {
-       arg.b_element_op_(v_scale, arg.scale_k_n_(k, n));
-   }
+   arg.a_element_op_(v_a, arg.a_m_k_(m, k));
+   arg.b_element_op_(v_b, arg.b_k_n_(k, n));
+   arg.b_element_op_(v_scale, arg.scale_k_n_(k, n));
    v_converted_b = type_convert<ADataType>(v_b) * v_scale;
    v_acc += ck::type_convert<AccDataType>(v_a) *
...
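(Aside: with the ConvertBF16RTN special-casing removed above, the reference inner loop reduces to "apply the element-wise operators, dequantize B with its scale, multiply-accumulate". A minimal scalar sketch of that step follows; the float/int8 types and flat indexing are assumptions standing in for the template parameters, and the element-wise operators are taken to have been applied already.)

// Hedged sketch of the fpA-intB reference dot product for one (m, n) output element.
#include <cstdint>

float fp_a_int_b_dot(const float* a_row,        // A(m, 0..K-1) after a_element_op_
                     const std::int8_t* b_col,  // B(0..K-1, n), quantized
                     const float* scale_col,    // scale(0..K-1, n)
                     int K)
{
    float acc = 0.f;
    for(int k = 0; k < K; ++k)
    {
        const float v_converted_b = static_cast<float>(b_col[k]) * scale_col[k]; // dequantize
        acc += a_row[k] * v_converted_b;                                         // accumulate
    }
    return acc;
}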
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -68,13 +68,7 @@ struct ReferenceGemm : public device::BaseOperator
 for(int k = 0; k < K; ++k)
 {
-   // use PassThrough instead of ConvertBF16RTN for reference calculation
-   if constexpr(is_same_v<AElementwiseOperation,
-                ck::tensor_operation::element_wise::ConvertBF16RTN>)
-   {
-       ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
-   }
-   else if constexpr(is_same_v<ADataType, pk_i4_t>)
+   if constexpr(is_same_v<ADataType, pk_i4_t>)
    {
        uint8_t i4x2 = arg.a_m_k_(m, k).data;
        int8_t i4 = 0;
@@ -89,13 +83,8 @@ struct ReferenceGemm : public device::BaseOperator
    {
        arg.a_element_op_(v_a, arg.a_m_k_(m, k));
    }
-   // same for B matrix
-   if constexpr(is_same_v<BElementwiseOperation,
-                ck::tensor_operation::element_wise::ConvertBF16RTN>)
-   {
-       ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
-   }
-   else if constexpr(is_same_v<BDataType, pk_i4_t>)
+   if constexpr(is_same_v<BDataType, pk_i4_t>)
    {
        uint8_t i4x2 = arg.b_k_n_(k, n).data;
        int8_t i4 = 0;
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -74,26 +74,8 @@ struct ReferenceGemmMultipleD : public device::BaseOperator
 for(int k = 0; k < K; ++k)
 {
-   // use PassThrough instead of ConvertBF16RTN for reference calculation
-   if constexpr(is_same_v<AElementwiseOperation,
-                ck::tensor_operation::element_wise::ConvertBF16RTN>)
-   {
-       ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
-   }
-   else
-   {
-       arg.a_element_op_(v_a, arg.a_m_k_(m, k));
-   }
-   // same for B matrix
-   if constexpr(is_same_v<BElementwiseOperation,
-                ck::tensor_operation::element_wise::ConvertBF16RTN>)
-   {
-       ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
-   }
-   else
-   {
-       arg.b_element_op_(v_b, arg.b_k_n_(k, n));
-   }
+   arg.a_element_op_(v_a, arg.a_m_k_(m, k));
+   arg.b_element_op_(v_b, arg.b_k_n_(k, n));
    v_acc +=
        ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
...
@@ -126,6 +126,35 @@ void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(
    PassThrough>>>& instances);
 #endif
// bf16_inputA bf16_inputB
#if defined(CK_ENABLE_BF16)
void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif // CK_ENABLE_BF16
 template <typename ALayout,
           typename BLayout,
           typename ELayout,
@@ -227,6 +256,24 @@ struct DeviceOperationInstanceFactory<
    }
 #endif
// bf16_inputA bf16_inputB
#if defined(CK_ENABLE_BF16)
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
is_same_v<EDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs);
}
}
#endif // CK_ENABLE_BF16
    return op_ptrs;
 }
 };
...
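(Usage note, offered as a hedged sketch rather than repository code: once the declarations and factory branches above are in place, the new bf16 x bf16 instances can be enumerated through the instance factory in the usual way. The include paths and the GetInstances() entry point below are assumptions based on how the existing fixed-NK instances are consumed elsewhere.)

// Illustrative only: count the row-major x column-major bf16 grouped-GEMM fixed-NK instances.
#include <cstddef>
#include <memory>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"         // assumed path
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp" // assumed path

using Row         = ck::tensor_layout::gemm::RowMajor;
using Col         = ck::tensor_layout::gemm::ColumnMajor;
using BF16        = ck::bhalf_t;
using Empty_Tuple = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Device operation matching the new mk_nk_mn bf16 instances registered above.
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK<Row,
                                                                        Col,
                                                                        Empty_Tuple,
                                                                        Row,
                                                                        BF16,
                                                                        BF16,
                                                                        Empty_Tuple,
                                                                        BF16,
                                                                        PassThrough,
                                                                        PassThrough,
                                                                        PassThrough>;

std::size_t count_bf16_fixed_nk_instances()
{
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    return op_ptrs.size(); // each entry can then be checked with IsSupportedArgument(...)
}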
@@ -8,6 +8,8 @@ list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16
 device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp
 device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp
 device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp
-device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp)
+device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp
+device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp
+device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp)
 add_instance_library(device_grouped_gemm_fixed_nk_instance ${GROUPED_GEMM_FIXED_NK_INSTANCES})
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using DsDataType = ck::Tuple<>;
using DsLayout = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_irregular_tile_instances =
std::tuple<
// clang-format off
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// clang-format on
>;
void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_irregular_tile_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using DsDataType = ck::Tuple<>;
using DsLayout = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_irregular_tile_instances =
std::tuple<
// clang-format off
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;
void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_irregular_tile_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -342,7 +342,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
 if(do_log)
 {
    LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
-   LogRangeAsType<int8_t>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
+   LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
    LogRangeAsType<float>(
        std::cout << "c_host : ", c_m_n_host_result.mData, ",")
        << std::endl;
...
@@ -17,11 +17,11 @@ enum struct GemmMatrixLayout
 enum struct GemmDataType
 {
    BF16_I8_BF16,  // 0
    F16_F16_F16,   // 1
    F16_F8_F16,    // 2
    F16_I8_F16,    // 3
+   BF16_BF16_BF16 // 4
 };
 #define OP_NAME "grouped_gemm_fixed_nk"
@@ -39,7 +39,6 @@ std::vector<int> argToIntArray(char* input)
 {
    out.push_back(std::stoi(item));
 }
 return out;
 }
@@ -83,14 +82,6 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
 const auto StrideCs = argToIntArray(argv[13]);
 const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1;
-using F32 = float;
-using F16 = ck::half_t;
-#if defined(CK_ENABLE_FP8)
-using F8 = ck::f8_t;
-#endif
-using BF16 = ck::bhalf_t;
-using I8 = int8_t;
 int n_warmup = 1;
 int n_iter = 10;
 if(argc == 17)
@@ -99,13 +90,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
    n_iter = std::stoi(argv[16]);
 }
-#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
-if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
+if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
 {
-   ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16,
-                                                    I8,
-                                                    BF16,
-                                                    F32,
+   ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
+                                                    ck::half_t,
+                                                    ck::half_t,
+                                                    float,
                                                     ck::tensor_layout::gemm::RowMajor,
                                                     ck::tensor_layout::gemm::RowMajor,
                                                     ck::tensor_layout::gemm::RowMajor>(
@@ -123,12 +113,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
        n_warmup,
        n_iter);
 }
-else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
+else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
 {
-   ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16,
-                                                    I8,
-                                                    BF16,
-                                                    F32,
+   ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
+                                                    ck::half_t,
+                                                    ck::half_t,
+                                                    float,
                                                     ck::tensor_layout::gemm::RowMajor,
                                                     ck::tensor_layout::gemm::ColumnMajor,
                                                     ck::tensor_layout::gemm::RowMajor>(
@@ -146,14 +136,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
        n_warmup,
        n_iter);
 }
-#endif
-#if defined(CK_ENABLE_FP16)
-else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+#if defined(CK_ENABLE_FP8)
+else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
 {
-   ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
-                                                    F16,
-                                                    F16,
-                                                    F32,
+   ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
+                                                    ck::f8_t,
+                                                    ck::half_t,
+                                                    float,
                                                     ck::tensor_layout::gemm::RowMajor,
                                                     ck::tensor_layout::gemm::RowMajor,
                                                     ck::tensor_layout::gemm::RowMajor>(
@@ -171,12 +160,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
        n_warmup,
        n_iter);
 }
-else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
 {
-   ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
-                                                    F16,
-                                                    F16,
-                                                    F32,
+   ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
+                                                    ck::f8_t,
+                                                    ck::half_t,
+                                                    float,
                                                     ck::tensor_layout::gemm::RowMajor,
                                                     ck::tensor_layout::gemm::ColumnMajor,
                                                     ck::tensor_layout::gemm::RowMajor>(
@@ -194,14 +183,14 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
        n_warmup,
        n_iter);
 }
-#endif
-#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)
-else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+#endif // CK_ENABLE_FP8
+#if defined(CK_ENABLE_INT8)
+else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
 {
-   ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
-                                                    F8,
-                                                    F16,
-                                                    F32,
+   ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
+                                                    int8_t,
+                                                    ck::half_t,
+                                                    float,
                                                     ck::tensor_layout::gemm::RowMajor,
                                                     ck::tensor_layout::gemm::RowMajor,
                                                     ck::tensor_layout::gemm::RowMajor>(
@@ -219,12 +208,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
        n_warmup,
        n_iter);
 }
-else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
 {
-   ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
-                                                    F8,
-                                                    F16,
-                                                    F32,
+   ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
+                                                    int8_t,
+                                                    ck::half_t,
+                                                    float,
                                                     ck::tensor_layout::gemm::RowMajor,
                                                     ck::tensor_layout::gemm::ColumnMajor,
                                                     ck::tensor_layout::gemm::RowMajor>(
@@ -242,14 +231,14 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
        n_warmup,
        n_iter);
 }
-#endif
-#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8)
-else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+#endif // CK_ENABLE_INT8
+#if defined(CK_ENABLE_BF16)
+else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
 {
-   ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
-                                                    I8,
-                                                    F16,
-                                                    F32,
+   ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
+                                                    ck::bhalf_t,
+                                                    ck::bhalf_t,
+                                                    float,
                                                     ck::tensor_layout::gemm::RowMajor,
                                                     ck::tensor_layout::gemm::RowMajor,
                                                     ck::tensor_layout::gemm::RowMajor>(
@@ -267,12 +256,59 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
        n_warmup,
        n_iter);
 }
-else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
#if defined(CK_ENABLE_INT8)
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
 {
-   ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
-                                                    I8,
-                                                    F16,
-                                                    F32,
+   ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
+                                                    int8_t,
+                                                    ck::bhalf_t,
+                                                    float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
int8_t,
ck::bhalf_t,
float,
                                                     ck::tensor_layout::gemm::RowMajor,
                                                     ck::tensor_layout::gemm::ColumnMajor,
                                                     ck::tensor_layout::gemm::RowMajor>(
@@ -286,11 +322,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
        StrideAs,
        StrideBs,
        StrideCs,
-       1,
+       kbatch,
        n_warmup,
        n_iter);
 }
-#endif
+#endif // CK_ENABLE_INT8
+#endif // CK_ENABLE_BF16
 else
 {
    throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
...
@@ -7,6 +7,34 @@ include(gtest)
 add_custom_target(tests)
# list of tests that are labelled as REGRESSION_TEST for make regression (runtime more than 30 seconds)
# all other tests are labelled as SMOKE_TEST
set(REGRESSION_TESTS
test_gemm_standalone_xdl_fp16
test_gemm_fp16
test_gemm_splitk
test_batched_gemm
test_gemm_universal
test_batched_gemm_softmax_gemm_fp16
test_batched_gemm_softmax_gemm_permute_fp16
test_batched_gemm_bias_softmax_gemm_permute_fp16
test_batched_gemm_softmax_gemm_permute_bf16
test_batched_gemm_bias_softmax_gemm_permute_bf16
test_grouped_gemm_splitk
test_reduce_no_index
test_reduce_with_index
test_convnd_fwd
test_convnd_bwd_data
test_grouped_convnd_fwd
test_grouped_convnd_bwd_weight
test_softmax_rank3
test_softmax_rank4
test_batchnorm_fwd_rank_4
test_batchnorm_bwd_rank_4
test_grouped_convnd_bwd_data_xdl
test_conv_tensor_rearrange
)
 function(add_test_executable TEST_NAME)
 message("adding test ${TEST_NAME}")
 set(result 1)
@@ -88,6 +116,15 @@ function(add_test_executable TEST_NAME)
 endif()
 #message("add_test returns ${result}")
 set(result ${result} PARENT_SCOPE)
if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
message("adding to SMOKE TEST FILTER ${TEST_NAME}")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST")
add_dependencies(smoke ${TEST_NAME})
elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST")
add_dependencies(regression ${TEST_NAME})
endif()
 endfunction()
 function(add_gtest_executable TEST_NAME)
@@ -168,6 +205,15 @@ function(add_gtest_executable TEST_NAME)
 endif()
 #message("add_gtest returns ${result}")
 set(result ${result} PARENT_SCOPE)
if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
#message("adding to smoke test FILTER ${TEST_NAME}")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST")
add_dependencies(smoke ${TEST_NAME})
elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
#message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST")
add_dependencies(regression ${TEST_NAME})
endif()
 endfunction()
 add_compile_options(-Wno-c++20-extensions)
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
 #include <sstream>
@@ -61,7 +61,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
    ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
    ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-   using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
+   using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
    using GemmEpilogue = std::conditional_t<
        CShuffleEpilogue,
@@ -73,8 +73,8 @@ class TestCkTileBatchedGemm : public ::testing::Test
        kOutputRank,
        1,
        0,
-       TilePartitioner::kM,
-       TilePartitioner::kN>>,
+       TilePartitioner::MPerBlock,
+       TilePartitioner::NPerBlock>>,
        ck_tile::Default2DEpilogue<
            ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
...
@@ -59,7 +59,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
    ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                           ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                           ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-   using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
+   using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
    using GemmEpilogue = ck_tile::Default2DEpilogue<
        ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
...
@@ -49,3 +49,4 @@ if(result EQUAL 0)
 endif()
 add_gtest_executable(test_type_convert_const type_convert_const.cpp)
+add_gtest_executable(test_bhalf test_bhalf.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bhalf_t;
using ck::type_convert;
TEST(BHALF_T, Nan)
{
const uint16_t binary_bhalf_nan = 0x7FC0;
const bhalf_t bhalf_nan = ck::bit_cast<bhalf_t>(binary_bhalf_nan);
EXPECT_EQ(bhalf_nan, type_convert<bhalf_t>(ck::NumericLimits<float>::QuietNaN()));
}
TEST(BHALF_T, Inf)
{
const uint16_t binary_bhalf_inf = 0x7F80;
const bhalf_t bhalf_inf = ck::bit_cast<bhalf_t>(binary_bhalf_inf);
EXPECT_EQ(bhalf_inf, type_convert<bhalf_t>(ck::NumericLimits<float>::Infinity()));
}
TEST(BHALF_T, MantisaOverflow)
{
const float abs_tol = std::pow(2, -7);
const uint32_t val = 0x81FFFFFF;
const float float_val = ck::bit_cast<float>(val);
ASSERT_NEAR(float_val, type_convert<float>(type_convert<bhalf_t>(float_val)), abs_tol);
}
TEST(BHALF_T, ExpOverflow)
{
const uint32_t val = 0xFF800000;
const float float_val = ck::bit_cast<float>(val);
ASSERT_EQ(type_convert<float>(type_convert<bhalf_t>(float_val)), float_val);
}
TEST(BHALF_T, MantisaExpOverflow)
{
const uint32_t val = 0xFFFFFFFF;
const float float_val = ck::bit_cast<float>(val);
ASSERT_TRUE(std::isnan(float_val));
ASSERT_TRUE(std::isnan(type_convert<float>(type_convert<bhalf_t>(float_val))));
}
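(Background for the edge cases exercised above: a bfloat16 value keeps the float32 sign bit, the full 8-bit exponent and the top 7 mantissa bits, so converting a float amounts to dropping the low 16 bits under some rounding rule. The sketch below shows a common round-to-nearest-even conversion; it is illustrative only and not necessarily the exact rounding performed by ck::type_convert<bhalf_t>.)

#include <cstdint>
#include <cstring>

// Illustrative float -> bf16 (upper 16 bits of the float32 pattern) with
// round-to-nearest-even on the discarded bits; NaN is canonicalized to 0x7FC0.
std::uint16_t float_to_bf16_rne(float f)
{
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    const bool is_nan = ((bits >> 23) & 0xFFu) == 0xFFu && (bits & 0x7FFFFFu) != 0u;
    if(is_nan)
        return 0x7FC0u; // quiet-NaN pattern, as checked by the Nan test above
    const std::uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);
    return static_cast<std::uint16_t>((bits + rounding_bias) >> 16);
}
// e.g. float_to_bf16_rne of +infinity yields 0x7F80, the pattern used in the Inf test.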