Commit 88b978c5 authored by Jun Liu

Merge branch 'develop' into amd-develop

parents e4112de7 6fb1f4e0
@@ -6,6 +6,22 @@
#include <hip/hip_runtime.h>
namespace ck_tile {
+/*
+ * Construct this structure as follows:
+ *
+ * // create stream config with the default stream (NULL), without timing the kernel
+ * stream_config s = stream_config{};
+ *
+ * // create stream config with _some_stream_id_, without timing the kernel
+ * stream_config s = stream_config{_some_stream_id_};
+ *
+ * // create stream config with _some_stream_id_, and benchmark with default warmup/repeat
+ * stream_config s = stream_config{_some_stream_id_, true};
+ *
+ * // create stream config with _some_stream_id_, and benchmark using the cpu timer
+ * stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, false};
+ */
struct stream_config
{
hipStream_t stream_id_ = nullptr;
@@ -13,5 +29,6 @@ struct stream_config
int log_level_ = 0;
int cold_niters_ = 3;
int nrepeat_ = 10;
+bool is_gpu_timer_ = true; // kept for backward compatibility
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <hip/hip_runtime.h>
#include <cstddef>
#include <chrono>
namespace ck_tile {
struct gpu_timer
{
CK_TILE_HOST gpu_timer()
{
HIP_CHECK_ERROR(hipEventCreate(&start_evt));
HIP_CHECK_ERROR(hipEventCreate(&stop_evt));
}
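// noexcept(false): HIP_CHECK_ERROR may throw if event destruction fails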
CK_TILE_HOST ~gpu_timer() noexcept(false)
{
HIP_CHECK_ERROR(hipEventDestroy(start_evt));
HIP_CHECK_ERROR(hipEventDestroy(stop_evt));
}
CK_TILE_HOST void start(const hipStream_t& s)
{
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipEventRecord(start_evt, s));
}
CK_TILE_HOST void stop(const hipStream_t& s)
{
HIP_CHECK_ERROR(hipEventRecord(stop_evt, s));
HIP_CHECK_ERROR(hipEventSynchronize(stop_evt));
}
// return in ms
CK_TILE_HOST float duration() const
{
float ms = 0;
HIP_CHECK_ERROR(hipEventElapsedTime(&ms, start_evt, stop_evt));
return ms;
}
private:
hipEvent_t start_evt, stop_evt;
};
struct cpu_timer
{
// like torch.utils.benchmark.Timer(): there is a device sync inside each timer callback
CK_TILE_HOST void start(const hipStream_t&)
{
HIP_CHECK_ERROR(hipDeviceSynchronize());
start_tick = std::chrono::high_resolution_clock::now();
}
// like torch.utils.benchmark.Timer(): there is a device sync inside each timer callback
CK_TILE_HOST void stop(const hipStream_t&)
{
HIP_CHECK_ERROR(hipDeviceSynchronize());
stop_tick = std::chrono::high_resolution_clock::now();
}
// return in ms
CK_TILE_HOST float duration() const
{
double sec =
std::chrono::duration_cast<std::chrono::duration<double>>(stop_tick - start_tick)
.count();
return static_cast<float>(sec * 1e3);
}
private:
std::chrono::time_point<std::chrono::high_resolution_clock> start_tick;
std::chrono::time_point<std::chrono::high_resolution_clock> stop_tick;
};
} // namespace ck_tile
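A minimal usage sketch for the timers above (the include paths and noop_kernel are assumptions for illustration, not part of this commit):

// Sketch only: time a trivial kernel on the default stream with the event-based
// timer, mirroring the stream_config usage comment above.
#include "ck_tile/host/stream_config.hpp" // assumed header location
#include "ck_tile/host/timer.hpp"         // assumed header location
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void noop_kernel() {}

int main()
{
    // default stream, benchmarking enabled, GPU timer (the default)
    ck_tile::stream_config cfg{nullptr, true};

    ck_tile::gpu_timer timer;
    timer.start(cfg.stream_id_);
    noop_kernel<<<dim3(1), dim3(64), 0, cfg.stream_id_>>>();
    timer.stop(cfg.stream_id_);

    std::printf("elapsed: %.3f ms\n", timer.duration());
    return 0;
}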
@@ -23,13 +23,13 @@ VERTICAL:
[0] 1 2 3 4 5
[0] 1 2 3 4 5
-TOP_LEFT:
+TOP_LEFT(but negative):
[0] 1 2 3 4 5
1 [0] 1 2 3 4
2 1 [0] 1 2 3
3 2 1 [0] 1 2
-FROM_BOTTOM_RIGHT:
+FROM_BOTTOM_RIGHT(but negative):
2 1 [0] 1 2 3
3 2 1 [0] 1 2
4 3 2 1 [0] 1
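A standalone sketch (plain C++, an illustration rather than the library code) that reproduces the matrices above; per the updated labels, both diagonal modes produce negated relative positions:

#include <cstdio>
#include <cstdlib>

// Print the relative-position matrix for a y_total x x_total tile.
// FROM_TOP_LEFT anchors the zero diagonal at (0, 0); FROM_BOTTOM_RIGHT shifts it
// right by (x_total - y_total). Values are negated, matching the new labels.
void print_alibi_positions(int y_total, int x_total, bool from_bottom_right)
{
    const int shift = from_bottom_right ? (x_total - y_total) : 0;
    for(int y = 0; y < y_total; ++y)
    {
        for(int x = 0; x < x_total; ++x)
            std::printf("%3d", -std::abs(x - y - shift));
        std::printf("\n");
    }
}

int main()
{
    print_alibi_positions(4, 6, false); // TOP_LEFT(but negative)
    print_alibi_positions(4, 6, true);  // FROM_BOTTOM_RIGHT(but negative)
    return 0;
}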
@@ -54,7 +54,7 @@ struct Alibi
index_t x_total_,
AlibiMode mode_ = AlibiMode::VERTICAL)
{
-slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope;
+slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_;
shift_left_up = [&]() {
if(RowMajor)
......
@@ -76,7 +76,7 @@ struct FmhaFwdKernel
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
-"_" + (kIsGroupMode ? "group" : "batch") + "_" +
+"_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
@@ -702,7 +702,7 @@ struct FmhaFwdKernel
else
{
return Alibi<SaccDataType, true>{
-slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL};
+slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
}
}
else
......
@@ -18,10 +18,12 @@ struct FmhaFwdTilePartitioner
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
-__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
-ck_tile::index_t nhead_,
-ck_tile::index_t seqlen_q_,
-ck_tile::index_t hdim_v_)
+static constexpr const char* name = "shb";
+
+CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
+ck_tile::index_t nhead_,
+ck_tile::index_t seqlen_q_,
+ck_tile::index_t hdim_v_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
@@ -51,4 +53,53 @@ struct FmhaFwdTilePartitioner
}
};
template <typename BlockFmhaShape_>
using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner<BlockFmhaShape_>;
template <typename BlockFmhaShape_>
struct FmhaFwdTilePartitioner_HBS
{
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
static constexpr const char* name = "hbs";
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
// TODO: this may need tuning
return dim3(nhead_,
batch_size_,
ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
ck_tile::integer_divide_ceil(hdim_v_, kN1));
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const index_t i_block = blockIdx.z;
const index_t i_nhead = blockIdx.x;
const index_t i_batch = blockIdx.y;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
};
} // namespace ck_tile
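For intuition, a standalone host-side sketch (an illustration, not the library's device code) of the index decomposition that FmhaFwdTilePartitioner_HBS::operator() performs on blockIdx.z:

#include <cstdio>
#include <utility>

// Split a flat block index into (tile_m, tile_n): seqlen tiles are the major
// axis, hdim_v tiles the minor axis, mirroring the divmod lambda above.
std::pair<int, int> decompose_block(int i_block, int num_tile_n1)
{
    const int quotient = i_block / num_tile_n1;
    const int modulus  = i_block - quotient * num_tile_n1;
    return {quotient, modulus};
}

int main()
{
    // e.g. hdim_v = 256 with kN1 = 64 gives num_tile_n1 = 4
    const auto [i_tile_m, i_tile_n] = decompose_block(10, 4);
    std::printf("tile_m=%d tile_n=%d\n", i_tile_m, i_tile_n); // tile_m=2 tile_n=2
    return 0;
}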
@@ -2,9 +2,14 @@
set(GEMM_MULTI_ABD_INSTANCES)
list(APPEND GEMM_MULTI_ABD_INSTANCES
+device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
+device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ELayout,
AsDataType,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
EDataType,
AElementOp,
Multiply,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
Multiply,
PassThrough,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
Multiply,
PassThrough,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
Multiply,
Add>>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
Multiply,
Add,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
Multiply,
Add,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
@@ -52,112 +52,6 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
Interwave>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
Multiply,
Add>>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
Multiply,
Add,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
Multiply,
Add,
GemmMNKPadding,
Interwave>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ELayout,
AsDataType,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
EDataType,
AElementOp,
Multiply,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
Multiply,
PassThrough,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
Multiply,
PassThrough,
GemmMNKPadding,
Interwave>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ELayout,
AsDataType,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
EDataType,
AElementOp,
Multiply,
FastGelu>>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
Multiply,
FastGelu,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
Multiply,
FastGelu,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ELayout,
AsDataType,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
EDataType,
AElementOp,
Multiply,
FastGelu>>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
Multiply,
FastGelu,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
Multiply,
FastGelu,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
EDataType,
AElementOp,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
PassThrough,
Multiply,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
PassThrough,
Multiply,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
EDataType,
AElementOp,
PassThrough,
MultiplyAdd>>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
PassThrough,
MultiplyAdd,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
PassThrough,
MultiplyAdd,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
@@ -52,111 +52,6 @@ void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_i
Interwave>{});
}
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
EDataType,
AElementOp,
PassThrough,
MultiplyAdd>>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
PassThrough,
MultiplyAdd,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
PassThrough,
MultiplyAdd,
GemmMNKPadding,
Interwave>{});
}
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
EDataType,
AElementOp,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
PassThrough,
Multiply,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
PassThrough,
Multiply,
GemmMNKPadding,
Interwave>{});
}
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
EDataType,
AElementOp,
PassThrough,
MultiplyFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
PassThrough,
MultiplyFastGelu,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
PassThrough,
MultiplyFastGelu,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
EDataType,
AElementOp,
PassThrough,
MultiplyFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
PassThrough,
MultiplyFastGelu,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
PassThrough,
MultiplyFastGelu,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
[build-system]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"
[project]
name = "rocm-composable-kernel"
dynamic = ["version"]
description = "Composable Kernel, performance-critical kernels for machine learning workloads"
readme = "README.md"
requires-python = ">=3.8"
license = {file = "LICENSE"}
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = []
[project.urls]
"Homepage" = "https://github.com/rocm/composable_kernel"
"Bug Tracker" = "https://github.com/rocm/composable_kernel/issues"
[tool.setuptools]
packages = ["ck4inductor", "ck4inductor.include", "ck4inductor.library"]
[tool.setuptools.package-dir]
ck4inductor = "python/ck4inductor"
"ck4inductor.include" = "include"
"ck4inductor.library" = "library"
[tool.setuptools.package-data]
"ck4inductor.include" = ["ck/**/*.hpp"]
"ck4inductor.library" = ["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp"]
[tool.setuptools.dynamic]
version = { attr = "setuptools_scm.get_version" }
import logging
import os
import subprocess
from dataclasses import fields, replace
from functools import lru_cache, partial
from typing import List
from ..util import library_path
from .op import CKGemmOperation
log = logging.getLogger(__name__)
def _ck_library_dir():
gemm_instances_path = os.path.join(
library_path(), "src", "tensor_operation_instance", "gpu", "gemm_universal"
)
if not os.path.exists(gemm_instances_path):
log.error("CK library path %s does not exist", gemm_instances_path)
return None
return gemm_instances_path
def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
"""
Parse the lines containing Universal Gemm template instances into `CKGemmOperation` instances
"""
def maybe_int(s):
try:
return int(s)
except ValueError:
return s
op_instances = []
for line in str_instances:
s_template_args = line.split("DeviceGemm_Xdl_CShuffleV3")[-1].strip("<>, ")
template_args = []
i_current = 0
while i_current < len(s_template_args):
if s_template_args[i_current] == " ":
# skip whitespace
i_current += 1
continue
elif s_template_args[i_current : i_current + 2] == "S<":
# parse template S<Index...>
i_next = s_template_args.find(">", i_current)
template_args.append(
tuple(map(int, s_template_args[i_current + 2 : i_next].split(",")))
)
i_current = i_next + 2
else:
# all string attributes must be either type aliases or global constants in C++
i_next = s_template_args.find(",", i_current)
template_args.append(
maybe_int(
s_template_args[i_current : i_next if i_next != -1 else None]
)
)
if i_next != -1:
i_current = i_next + 1
if i_next == -1:
break
# pad with `None`s for the fields which are not defined in the instance
new_instance = CKGemmOperation(
*template_args, # type: ignore[arg-type]
*((None,) * (len(fields(CKGemmOperation)) - len(template_args))),
)
# the last 2 template parameters are optional
# if they are absent, substitute them with the default values from the Universal Gemm C++ template declaration
if new_instance.a_compute_dtype is None:
new_instance.a_compute_dtype = new_instance.c_element_dtype
if new_instance.b_compute_dtype is None:
new_instance.b_compute_dtype = new_instance.c_element_dtype
op_instances.append(new_instance)
return op_instances
def default_instances() -> List[CKGemmOperation]:
# fallback: known working op instance for problem size M=2240 K=256 N=2048
# all string attributes must be either type aliases or global constants in C++
return [
CKGemmOperation(
a_layout="Row",
b_layout="Row",
c_layout="Row",
a_element_dtype="F16",
b_element_dtype="F16",
c_element_dtype="F16",
a_compute_dtype="F16",
b_compute_dtype="F16",
acc_dtype="F32",
c_shuffle_dtype="F16",
a_elementwise_op="PassThrough",
b_elementwise_op="PassThrough",
c_elementwise_op="PassThrough",
gemm_specialization="GemmSpecialization::Default",
block_size=256,
m_per_block=224,
n_per_block=256,
k_per_block=64,
a_k1=8,
b_k1=2,
m_per_xdl=16,
n_per_xdl=16,
m_xdl_per_wave=7,
n_xdl_per_wave=8,
a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1),
a_block_transfer_thread_cluster_arrange_order=(1, 0, 2),
a_block_transfer_src_access_order=(1, 0, 2),
a_block_transfer_src_vector_dim=2,
a_block_transfer_src_scalar_per_vector=8,
a_block_transfer_dst_scalar_per_vector_ak1=8,
a_block_lds_extra_m=0, # type: ignore[arg-type]
b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1),
b_block_transfer_thread_cluster_arrange_order=(0, 2, 1),
b_block_transfer_src_access_order=(0, 2, 1),
b_block_transfer_src_vector_dim=1,
b_block_transfer_src_scalar_per_vector=8,
b_block_transfer_dst_scalar_per_vector_bk1=2,
b_block_lds_extra_n=0, # type: ignore[arg-type]
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=2,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
1,
32,
1,
8,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
)
]
@lru_cache(None)
def gen_ops_library() -> List[CKGemmOperation]:
"""
Parse the Universal Gemm instances defined in the composable kernel library folder.
"""
ck_library_dir = _ck_library_dir()
if not ck_library_dir:
return []
grep_result = subprocess.run(
[
"grep",
"-inR",
"DeviceGemm_Xdl_CShuffleV3",
ck_library_dir,
],
capture_output=True,
text=True,
)
op_instances = parse_instances(grep_result.stdout.strip().split("\n"))
log.debug("ck instances from library: %d", len(op_instances))
schedulers = [
"BlockGemmPipelineScheduler::Intrawave",
"BlockGemmPipelineScheduler::Interwave",
]
gemm_specs = [
"GemmSpecialization::Default",
"GemmSpecialization::MPadding",
"GemmSpecialization::NPadding",
"GemmSpecialization::KPadding",
"GemmSpecialization::MNPadding",
"GemmSpecialization::MKPadding",
"GemmSpecialization::NKPadding",
"GemmSpecialization::MNKPadding",
]
# substitute templated args by looping through their domains
substitute_instances = []
for instance in op_instances:
sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
sub_spec = instance.gemm_specialization == "GemmSpec"
schedulers_range = (
schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler]
)
spec_range = gemm_specs if sub_spec else [instance.gemm_specialization]
for scheduler in schedulers_range:
for spec in spec_range:
substitute_instances.append(
replace(
instance,
block_gemm_pipeline_scheduler=scheduler,
gemm_specialization=spec,
)
)
return substitute_instances
@lru_cache(None)
def gen_ops_preselected() -> List[CKGemmOperation]:
"""
Manually selected (through benchmarking) F16/F16/F16 Row/Col/Row instances
"""
ck_gemm_f16_rcr = partial(
CKGemmOperation,
a_layout="Row",
b_layout="Col",
c_layout="Row",
a_element_dtype="F16",
b_element_dtype="F16",
c_element_dtype="F16",
acc_dtype="F32",
c_shuffle_dtype="F16",
a_elementwise_op="PassThrough",
b_elementwise_op="PassThrough",
c_elementwise_op="PassThrough",
k_per_block=64,
a_k1=8,
b_k1=8,
a_block_transfer_thread_cluster_arrange_order=(1, 0, 2),
a_block_transfer_src_access_order=(1, 0, 2),
a_block_transfer_src_vector_dim=2,
a_block_transfer_src_scalar_per_vector=8,
a_block_transfer_dst_scalar_per_vector_ak1=8,
a_block_lds_extra_m=0,
b_block_transfer_thread_cluster_arrange_order=(1, 0, 2),
b_block_transfer_src_access_order=(1, 0, 2),
b_block_transfer_src_vector_dim=2,
b_block_transfer_src_scalar_per_vector=8,
b_block_transfer_dst_scalar_per_vector_bk1=8,
b_block_lds_extra_n=0,
a_compute_dtype="F16",
b_compute_dtype="F16",
)
ck_gemm_f16_rcr_compute_friendly = partial(
ck_gemm_f16_rcr,
block_size=256,
a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1),
b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1),
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
1,
32,
1,
8,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
)
ck_gemm_f16_rcr_memory_friendly = partial(
ck_gemm_f16_rcr,
block_size=128,
a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1),
b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1),
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Interwave",
block_gemm_pipeline_version="BlockGemmPipelineVersion::v2",
)
ck_gemm_f16_rcr_latency_friendly = partial(
ck_gemm_f16_rcr,
gemm_specialization="GemmSpecialization::Default",
block_size=128,
m_per_xdl=16,
n_per_xdl=16,
m_xdl_per_wave=1,
n_xdl_per_wave=1,
a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1),
b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1),
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=1,
c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
block_gemm_pipeline_version="BlockGemmPipelineVersion::v1",
)
return [
ck_gemm_f16_rcr_compute_friendly(
gemm_specialization="GemmSpecialization::MNKPadding",
m_per_block=224,
n_per_block=256,
m_per_xdl=16,
n_per_xdl=16,
m_xdl_per_wave=7,
n_xdl_per_wave=8,
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=2,
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
),
ck_gemm_f16_rcr_compute_friendly(
gemm_specialization="GemmSpecialization::MNKPadding",
m_per_block=128,
n_per_block=128,
m_per_xdl=32,
n_per_xdl=32,
m_xdl_per_wave=2,
n_xdl_per_wave=2,
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=1,
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
),
ck_gemm_f16_rcr_compute_friendly(
gemm_specialization="GemmSpecialization::MNKPadding",
m_per_block=128,
n_per_block=128,
m_per_xdl=32,
n_per_xdl=32,
m_xdl_per_wave=2,
n_xdl_per_wave=2,
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=1,
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
block_gemm_pipeline_version="BlockGemmPipelineVersion::v4",
),
ck_gemm_f16_rcr_compute_friendly(
gemm_specialization="GemmSpecialization::MNKPadding",
m_per_block=128,
n_per_block=128,
m_per_xdl=32,
n_per_xdl=32,
m_xdl_per_wave=2,
n_xdl_per_wave=2,
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=1,
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
block_gemm_pipeline_version="BlockGemmPipelineVersion::v5",
),
ck_gemm_f16_rcr_compute_friendly(
gemm_specialization="GemmSpecialization::Default",
m_per_block=128,
n_per_block=128,
m_per_xdl=32,
n_per_xdl=32,
m_xdl_per_wave=2,
n_xdl_per_wave=2,
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=1,
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
),
ck_gemm_f16_rcr_compute_friendly(
gemm_specialization="GemmSpecialization::Default",
m_per_block=128,
n_per_block=128,
m_per_xdl=32,
n_per_xdl=32,
m_xdl_per_wave=2,
n_xdl_per_wave=2,
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=1,
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
block_gemm_pipeline_version="BlockGemmPipelineVersion::v4",
),
ck_gemm_f16_rcr_compute_friendly(
gemm_specialization="GemmSpecialization::Default",
m_per_block=128,
n_per_block=128,
m_per_xdl=32,
n_per_xdl=32,
m_xdl_per_wave=2,
n_xdl_per_wave=2,
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=1,
block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
block_gemm_pipeline_version="BlockGemmPipelineVersion::v5",
),
ck_gemm_f16_rcr_memory_friendly(
gemm_specialization="GemmSpecialization::Default",
m_per_block=16,
n_per_block=32,
m_per_xdl=16,
n_per_xdl=16,
m_xdl_per_wave=1,
n_xdl_per_wave=1,
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=1,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
1,
16,
1,
8,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
),
ck_gemm_f16_rcr_memory_friendly(
gemm_specialization="GemmSpecialization::MNKPadding",
m_per_block=16,
n_per_block=32,
m_per_xdl=16,
n_per_xdl=16,
m_xdl_per_wave=1,
n_xdl_per_wave=1,
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=1,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
1,
16,
1,
8,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
),
ck_gemm_f16_rcr_memory_friendly(
gemm_specialization="GemmSpecialization::MNKPadding",
m_per_block=16,
n_per_block=64,
m_per_xdl=16,
n_per_xdl=16,
m_xdl_per_wave=1,
n_xdl_per_wave=2,
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=2,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
1,
16,
1,
8,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
),
ck_gemm_f16_rcr_memory_friendly(
gemm_specialization="GemmSpecialization::MNKPadding",
m_per_block=32,
n_per_block=64,
m_per_xdl=32,
n_per_xdl=32,
m_xdl_per_wave=1,
n_xdl_per_wave=1,
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=1,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
1,
16,
1,
8,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
),
ck_gemm_f16_rcr_memory_friendly(
gemm_specialization="GemmSpecialization::MNKPadding",
m_per_block=32,
n_per_block=128,
m_per_xdl=32,
n_per_xdl=32,
m_xdl_per_wave=1,
n_xdl_per_wave=2,
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=1,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
1,
16,
1,
8,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
),
ck_gemm_f16_rcr_memory_friendly(
gemm_specialization="GemmSpecialization::Default",
m_per_block=32,
n_per_block=16,
m_per_xdl=16,
n_per_xdl=16,
m_xdl_per_wave=1,
n_xdl_per_wave=1,
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=1,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
1,
32,
1,
4,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
),
ck_gemm_f16_rcr_memory_friendly(
gemm_specialization="GemmSpecialization::MNKPadding",
m_per_block=32,
n_per_block=16,
m_per_xdl=16,
n_per_xdl=16,
m_xdl_per_wave=1,
n_xdl_per_wave=1,
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=1,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
1,
32,
1,
4,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
),
ck_gemm_f16_rcr_memory_friendly(
gemm_specialization="GemmSpecialization::MNKPadding",
m_per_block=64,
n_per_block=16,
m_per_xdl=16,
n_per_xdl=16,
m_xdl_per_wave=2,
n_xdl_per_wave=1,
c_shuffle_m_xdl_per_wave_per_shuffle=2,
c_shuffle_n_xdl_per_wave_per_shuffle=1,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
1,
64,
1,
2,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
),
ck_gemm_f16_rcr_memory_friendly(
gemm_specialization="GemmSpecialization::MNKPadding",
m_per_block=64,
n_per_block=32,
m_per_xdl=32,
n_per_xdl=32,
m_xdl_per_wave=1,
n_xdl_per_wave=1,
c_shuffle_m_xdl_per_wave_per_shuffle=1,
c_shuffle_n_xdl_per_wave_per_shuffle=1,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
1,
32,
1,
4,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
),
ck_gemm_f16_rcr_memory_friendly(
gemm_specialization="GemmSpecialization::MNKPadding",
m_per_block=128,
n_per_block=32,
m_per_xdl=32,
n_per_xdl=32,
m_xdl_per_wave=2,
n_xdl_per_wave=1,
c_shuffle_m_xdl_per_wave_per_shuffle=2,
c_shuffle_n_xdl_per_wave_per_shuffle=1,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
1,
32,
1,
4,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
),
ck_gemm_f16_rcr_latency_friendly(
m_per_block=16,
n_per_block=32,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
1,
16,
1,
8,
),
),
ck_gemm_f16_rcr_latency_friendly(
m_per_block=32,
n_per_block=16,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
1,
32,
1,
4,
),
),
]
if __name__ == "__main__":
print(gen_ops_library())
from dataclasses import asdict, dataclass
from typing import Optional, Tuple
@dataclass
class CKGemmOperation:
"""
A python dataclass storing the template parameters of a CK Universal Gemm template instance
"""
a_layout: str
b_layout: str
c_layout: str
a_element_dtype: str
b_element_dtype: str
c_element_dtype: str
acc_dtype: str
c_shuffle_dtype: str
a_elementwise_op: str
b_elementwise_op: str
c_elementwise_op: str
gemm_specialization: str
block_size: int
m_per_block: int
n_per_block: int
k_per_block: int
a_k1: int
b_k1: int
m_per_xdl: int
n_per_xdl: int
m_xdl_per_wave: int
n_xdl_per_wave: int
a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int]
a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
a_block_transfer_src_access_order: Tuple[int, int, int]
a_block_transfer_src_vector_dim: int
a_block_transfer_src_scalar_per_vector: int
a_block_transfer_dst_scalar_per_vector_ak1: int
a_block_lds_extra_m: bool
b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int]
b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
b_block_transfer_src_access_order: Tuple[int, int, int]
b_block_transfer_src_vector_dim: int
b_block_transfer_src_scalar_per_vector: int
b_block_transfer_dst_scalar_per_vector_bk1: int
b_block_lds_extra_n: bool
c_shuffle_m_xdl_per_wave_per_shuffle: int
c_shuffle_n_xdl_per_wave_per_shuffle: int
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: (
Tuple[int, int, int, int]
)
c_shuffle_block_transfer_scalar_per_vector_n_per_block: int
block_gemm_pipeline_scheduler: str
block_gemm_pipeline_version: Optional[str]
a_compute_dtype: Optional[str]
b_compute_dtype: Optional[str]
def name(self):
# cpp alias for template instance
return f"ck_devicegemm_xdl_shuffle_v3_{self.key_name()}"
def key_name(self):
# TBD; must be unique per instance. Intended for use as a dict key
return "_".join(
[
"K"
+ field_name.replace("_", "").lower()
+ "V"
+ (
"x".join(map(str, iter(field_value)))
if isinstance(field_value, tuple)
else str(field_value).replace(":", "")
)
for field_name, field_value in self.dict_items()
]
)
def dict_items(self):
return asdict(self).items()
import functools
import os
@functools.lru_cache(None)
def library_path():
return os.path.join(os.path.dirname(__file__), 'library')
@@ -131,74 +131,74 @@ int main()
0, 1, 2, 3, 4, 5,
0, 1, 2, 3, 4, 5});
-rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5,
-1, 0, 1, 2, 3, 4,
-2, 1, 0, 1, 2, 3,
-3, 2, 1, 0, 1, 2});
+rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, -4, -5,
+-1, 0, -1, -2, -3, -4,
+-2, -1, 0, -1, -2, -3,
+-3, -2, -1, 0, -1, -2});
-rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3,
-1, 0, 1, 2,
-2, 1, 0, 1,
-3, 2, 1, 0,
-4, 3, 2, 1,
-5, 4, 3, 2});
+rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3,
+-1, 0, -1, -2,
+-2, -1, 0, -1,
+-3, -2, -1, 0,
+-4, -3, -2, -1,
+-5, -4, -3, -2});
-rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2,
-1, 0, 1,
-2, 1, 0});
+rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2,
+-1, 0, -1,
+-2, -1, 0});
-rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3,
-3, 2, 1, 0, 1, 2,
-4, 3, 2, 1, 0, 1,
-5, 4, 3, 2, 1, 0});
+rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -1, 0, -1, -2, -3,
+-3, -2, -1, 0, -1, -2,
+-4, -3, -2, -1, 0, -1,
+-5, -4, -3, -2, -1, 0});
-rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5,
-1, 2, 3, 4,
-0, 1, 2, 3,
-1, 0, 1, 2,
-2, 1, 0, 1,
-3, 2, 1, 0});
+rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -3, -4, -5,
+-1, -2, -3, -4,
+0, -1, -2, -3,
+-1, 0, -1, -2,
+-2, -1, 0, -1,
+-3, -2, -1, 0});
-rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2,
-1, 0, 1,
-2, 1, 0});
+rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, { 0, -1, -2,
+-1, 0, -1,
+-2, -1, 0});
rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::VERTICAL, {0, 1, 2, 3, 4, 5,
0, 1, 2, 3, 4, 5,
0, 1, 2, 3, 4, 5,
0, 1, 2, 3, 4, 5});
-rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5,
-1, 0, 1, 2, 3, 4,
-2, 1, 0, 1, 2, 3,
-3, 2, 1, 0, 1, 2});
+rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, -4, -5,
+-1, 0, -1, -2, -3, -4,
+-2, -1, 0, -1, -2, -3,
+-3, -2, -1, 0, -1, -2});
-rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3,
-1, 0, 1, 2,
-2, 1, 0, 1,
-3, 2, 1, 0,
-4, 3, 2, 1,
-5, 4, 3, 2});
+rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3,
+-1, 0, -1, -2,
+-2, -1, 0, -1,
+-3, -2, -1, 0,
+-4, -3, -2, -1,
+-5, -4, -3, -2});
-rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2,
-1, 0, 1,
-2, 1, 0});
+rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2,
+-1, 0, -1,
+-2, -1, 0});
-rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3,
-3, 2, 1, 0, 1, 2,
-4, 3, 2, 1, 0, 1,
-5, 4, 3, 2, 1, 0});
+rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -1, 0, -1, -2, -3,
+-3, -2, -1, 0, -1, -2,
+-4, -3, -2, -1, 0, -1,
+-5, -4, -3, -2, -1, 0});
-rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5,
-1, 2, 3, 4,
-0, 1, 2, 3,
-1, 0, 1, 2,
-2, 1, 0, 1,
-3, 2, 1, 0});
+rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -3, -4, -5,
+-1, -2, -3, -4,
+0, -1, -2, -3,
+-1, 0, -1, -2,
+-2, -1, 0, -1,
+-3, -2, -1, 0});
-rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2,
-1, 0, 1,
-2, 1, 0});
+rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, { 0, -1, -2,
+-1, 0, -1,
+-2, -1, 0});
rtn &= test_alibi_slope_generation<float>(8, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625});
rtn &= test_alibi_slope_generation<float>(16, {0.7071067811865476, 0.5, 0.35355339059327384, 0.25000000000000006, 0.17677669529663692,
......
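For reference, a standalone sketch (not the library's implementation) of the standard ALiBi slope recipe that the two slope-generation checks above encode, for a power-of-two head count:

#include <cmath>
#include <cstdio>
#include <vector>

// Standard ALiBi slopes: for n heads (n a power of two), slope_i = 2^(-8/n * (i+1)).
// n = 8 yields 0.5, 0.25, ..., 0.00390625; n = 16 starts at 2^-0.5 ~ 0.70710678,
// matching the expected values in the tests above.
std::vector<float> alibi_slopes(int nheads)
{
    std::vector<float> slopes(nheads);
    const double base = std::pow(2.0, -8.0 / nheads);
    for(int i = 0; i < nheads; ++i)
        slopes[i] = static_cast<float>(std::pow(base, i + 1));
    return slopes;
}

int main()
{
    for(float s : alibi_slopes(8))
        std::printf("%g ", s);
    std::printf("\n");
    return 0;
}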