Unverified commit f9466a75 authored by Illia Silin, committed by GitHub

Merge branch 'develop' into rimadduri/grouped_gemm_async_memcpy

parents 89d8fca1 44828b7c
@@ -124,7 +124,7 @@ struct GemmPipelineAGmemBGmemCRegV1
         b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

     // Block GEMM
-    constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
+    auto block_gemm = Policy::template GetBlockGemm<Problem>();

     // Acc register tile
     auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
...
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
import logging
import os
import subprocess
from dataclasses import replace
from functools import lru_cache
from typing import List
from ..util import library_path
from .op import CKBatchedGemmOperation
log = logging.getLogger(__name__)
def _ck_library_dir():
gemm_instances_path = os.path.join(
library_path(),
"src",
"tensor_operation_instance",
"gpu",
"gemm_universal_batched",
)
if not os.path.exists(gemm_instances_path):
log.error("CK library path %s does not exist", gemm_instances_path)
return None
return gemm_instances_path
def parse_instances(str_instances: List[str]) -> List[CKBatchedGemmOperation]:
"""
Parse lines containing Batched Gemm template instances into `CKBatchedGemmOperation` instances.
"""
def maybe_int(s):
try:
return int(s)
except ValueError:
return s
op_instances = []
for line in str_instances:
s_template_args = line.split("DeviceBatchedGemmMultiD_Xdl_CShuffle_V3")[
-1
].strip("<>, ")
template_args = []
i_current = 0
while i_current < len(s_template_args):
if s_template_args[i_current] == " ":
# skip whitespace
i_current += 1
continue
elif s_template_args[i_current : i_current + 2] == "S<":
# parse template S<Index...>
i_next = s_template_args.find(">", i_current)
template_args.append(
tuple(map(int, s_template_args[i_current + 2 : i_next].split(",")))
)
i_current = i_next + 2
else:
# all string attributes must be either type aliases or global constants in C++
i_next = s_template_args.find(",", i_current)
template_args.append(
maybe_int(
s_template_args[i_current : i_next if i_next != -1 else None]
)
)
if i_next != -1:
i_current = i_next + 1
if i_next == -1:
break
# ds layout and dtype are parsed as placeholder; reset value
template_args[2] = tuple() # ds layout
template_args[6] = tuple() # ds dtype
new_instance = CKBatchedGemmOperation(
*template_args, # type: ignore[arg-type]
)
op_instances.append(new_instance)
return op_instances
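# Illustrative note (not part of the generated ops): the tokenizer above turns a
# toy fragment such as
#     "Row, Col, S<4, 64, 1>, 256"
# into
#     ["Row", "Col", (4, 64, 1), 256]
# i.e. S<...> groups become tuples of ints, bare integers go through maybe_int(),
# and any other token is kept as a string (a C++ type alias or constant name).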
@lru_cache(None)
def gen_ops_library() -> List[CKBatchedGemmOperation]:
"""
Parse the Batched Gemm instances defined in the composable kernel library folder.
"""
ck_library_dir = _ck_library_dir()
if not ck_library_dir:
return []
grep_result = subprocess.run(
[
"grep",
"-inR",
"DeviceBatchedGemmMultiD_Xdl_CShuffle_V3",
ck_library_dir,
],
capture_output=True,
text=True,
)
op_instances = parse_instances(grep_result.stdout.strip().split("\n"))
log.debug("ck instances from library: %d", len(op_instances))
schedulers = [
"BlockGemmPipelineScheduler::Intrawave",
"BlockGemmPipelineScheduler::Interwave",
]
gemm_specs = [
"GemmSpecialization::Default",
"GemmSpecialization::MPadding",
"GemmSpecialization::NPadding",
"GemmSpecialization::KPadding",
"GemmSpecialization::MNPadding",
"GemmSpecialization::MKPadding",
"GemmSpecialization::NKPadding",
"GemmSpecialization::MNKPadding",
]
# substitute templated args by looping through their domains
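    # For example, an instance parsed with the generic placeholders
    # "BlkGemmPipeSched" and "GemmSpec" expands into 2 schedulers x 8
    # specializations = 16 concrete instances, while an instance that already
    # names a concrete scheduler or specialization is kept unchanged.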
substitute_instances = []
for instance in op_instances:
sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
sub_spec = instance.gemm_specialization == "GemmSpec"
schedulers_range = (
schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler]
)
spec_range = gemm_specs if sub_spec else [instance.gemm_specialization]
for scheduler in schedulers_range:
for spec in spec_range:
substitute_instances.append(
replace(
instance,
block_gemm_pipeline_scheduler=scheduler,
gemm_specialization=spec,
)
)
return substitute_instances
if __name__ == "__main__":
print(gen_ops_library())
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
from dataclasses import asdict, dataclass
from typing import Optional, Tuple
@dataclass
class CKBatchedGemmOperation:
"""
A Python dataclass storing the template parameters of a CK Batched Gemm template instance.
"""
a_layout: str
b_layout: str
ds_layouts: Tuple[str] # addmm specific
c_layout: str
a_element_dtype: str
b_element_dtype: str
ds_element_dtypes: Tuple[str] # addmm specific
c_element_dtype: str
acc_dtype: str
c_shuffle_dtype: str
a_elementwise_op: str
b_elementwise_op: str
c_elementwise_op: str
gemm_specialization: str
block_size: int
m_per_block: int
n_per_block: int
k_per_block: int
a_k1: int
b_k1: int
m_per_xdl: int
n_per_xdl: int
m_xdl_per_wave: int
n_xdl_per_wave: int
a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int]
a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
a_block_transfer_src_access_order: Tuple[int, int, int]
a_block_transfer_src_vector_dim: int
a_block_transfer_src_scalar_per_vector: int
a_block_transfer_dst_scalar_per_vector_ak1: int
a_block_lds_extra_m: bool
b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int]
b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
b_block_transfer_src_access_order: Tuple[int, int, int]
b_block_transfer_src_vector_dim: int
b_block_transfer_src_scalar_per_vector: int
b_block_transfer_dst_scalar_per_vector_bk1: int
b_block_lds_extra_n: bool
c_shuffle_m_xdl_per_wave_per_shuffle: int
c_shuffle_n_xdl_per_wave_per_shuffle: int
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: (
Tuple[int, int, int, int]
)
c_shuffle_block_transfer_scalar_per_vector_n_per_block: Tuple[int]
block_gemm_pipeline_scheduler: str
block_gemm_pipeline_version: str
a_compute_dtype: Optional[str] = None
b_compute_dtype: Optional[str] = None
def name(self):
# cpp alias for template instance
return f"ck_device_batched_gemm_multi_d_xdl_c_shuffle_v3_{self.key_name()}"
def key_name(self):
# TBD; must be unique per instance. Intended to be used as a dict key.
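        # For example, block_size == 256 contributes "KblocksizeV256", a tuple
        # field with value (4, 64, 1) contributes "K<fieldname>V4x64x1", and
        # "::" is stripped from enum-like strings; the per-field pieces are
        # joined with "_".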
return "_".join(
[
"K"
+ field_name.replace("_", "").lower()
+ "V"
+ (
"x".join(map(str, iter(field_value)))
if isinstance(field_value, tuple)
else str(field_value).replace(":", "")
)
for field_name, field_value in self.dict_items()
]
)
def dict_items(self):
return asdict(self).items()
@@ -130,9 +130,7 @@ def gen_conv_ops_library() -> List[CKGroupedConvFwdOp]:
     # substitute templated args by looping through their domains
     substitute_instances = []
     for instance in op_instances:
-        sub_scheduler = (
-            instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
-        )
+        sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
         sub_spec = instance.conv_forward_specialization == "ConvSpec"
         schedulers_range = (
             schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler]
...
 add_subdirectory(image_to_column)
 add_subdirectory(gemm)
+add_subdirectory(batched_gemm)
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_ck_tile_batched_gemm test_batched_gemm.cpp)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_batched_gemm_util.hpp"
using F16 = ck_tile::half_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileBatchedGemm, KernelTypes);
#include "test_batched_gemm_ut_cases.inc"
#pragma once
TYPED_TEST(TestCkTileBatchedGemm, Basic)
{
constexpr int M = 256;
constexpr int N = 128;
constexpr int K = 128;
this->Run(M, N, K);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <sstream>
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
template <typename Tuple>
class TestCkTileBatchedGemm : public ::testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
{
};
template <typename ALayout, typename BLayout, typename CLayout>
void invoke_batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool kTilePermute = false;
// The rank and permutation will also be generated by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
constexpr int kBlockPerCu = 1;
// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
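        // Assuming the usual ck_tile TileGemmShape convention (block tile, warps
        // per block, warp tile), each 128x128x32 block tile is split across a
        // 2x2x1 grid of warps, so every warp covers a 64x64 sub-tile computed
        // from 32x32x8 MFMA warp tiles.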
        // Whether to apply the CShuffle epilogue (transpose before writing to
        // global memory) depends on the output layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
kPadM,
kPadN,
kTilePermute,
kOutputRank,
1,
0,
TilePartitioner::kM,
TilePartitioner::kN>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using Kernel =
ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args);
const dim3 grids = Kernel::GridSize(args);
constexpr dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
public:
void Run(const int M,
const int N,
const int K,
int StrideA = 128,
int StrideB = 128,
int StrideC = 128,
const int BatchStrideA = 32768,
const int BatchStrideB = 16384,
const int BatchStrideC = 32768,
const int BatchCount = 16)
{
using namespace ck_tile::literals;
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
{batch_stride, stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
{batch_stride, 1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
{
// if stride is zero, fall back to a default packed stride
if constexpr(std::is_same_v<decltype(layout),
ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
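        // For example, with M = 256, K = 128 and a stride argument of 0, a
        // RowMajor A falls back to the packed stride K = 128, while a
        // ColumnMajor A would fall back to M = 256.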
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
ck_tile::HostTensor<ADataType> a_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
ck_tile::HostTensor<BDataType> b_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
batched_gemm_kargs kargs{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
BatchCount};
invoke_batched_gemm<ALayout, BLayout, CLayout>(kargs,
ck_tile::stream_config{nullptr, false});
std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << StrideA << " StrideB =" << StrideB << " StrideC =" << StrideC
<< " BatchStrideA =" << BatchStrideA << " BatchStrideB =" << BatchStrideB
<< " BatchStrideC =" << BatchStrideC << " BatchCount =" << BatchCount
<< std::endl;
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
c_m_n_host_ref.SetZero();
const auto b_n_k = b_k_n.transpose({0, 2, 1});
ck_tile::reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_n_k, c_m_n_host_ref);
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
EXPECT_TRUE(pass);
}
};