Commit 4d914af3 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 223a2abe 4b798833
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
from dataclasses import asdict, dataclass
from typing import Optional, Tuple
@dataclass
class CKGroupedConvFwdOp:
n_dim_spatial: int
a_layout: str
b_layout: str
ds_layout: Tuple[str]
e_layout: str
a_element_dtype: str
b_element_dtype: str
acc_dtype: str
c_shuffle_dtype: str
ds_element_dtype: Tuple[str]
e_element_dtype: str
a_elementwise_op: str
b_elementwise_op: str
cde_elementwise_op: str
conv_forward_specialization: str
gemm_specialization: str
block_size: int
m_per_block: int
n_per_block: int
k_per_block: int
ak1: int
bk1: int
m_per_xdl: int
n_per_xdl: int
m_xdl_per_wave: int
n_xdl_per_wave: int
a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int]
a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
a_block_transfer_src_access_order: Tuple[int, int, int]
a_block_transfer_src_vector_dim: int
a_block_transfer_src_scalar_per_vector: int
a_block_transfer_dst_scalar_per_vector_ak1: int
a_block_lds_extra_m: bool
b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int]
b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
b_block_transfer_src_access_order: Tuple[int, int, int]
b_block_transfer_src_vector_dim: int
b_block_transfer_src_scalar_per_vector: int
b_block_transfer_dst_scalar_per_vector_bk1: int
b_block_lds_extra_n: bool
c_shuffle_m_xdl_per_wave_per_shuffle: int
c_shuffle_n_xdl_per_wave_per_shuffle: int
cde_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: Tuple[ # noqa
int,
int,
int,
int,
]
cde_block_transfer_scalar_per_vector_n_per_block: int
block_gemm_pipeline_scheduler: str
block_gemm_pipeline_version: str
a_compute_dtype: Optional[str] = None
b_compute_dtype: Optional[str] = None
def name(self):
# cpp alias for template instance
return (
f"ck_device_grouped_convolution_fwd_multiple_abd_xdl_c_shuffle_v3_"
f"{self.key_name()}"
)
def key_name(self):
# TBD; must be unique per instance. Intended to use as dict key
return "_".join(
[
"K"
+ field_name.replace("_", "").lower()
+ "V"
+ (
"x".join(map(str, iter(field_value)))
if isinstance(field_value, tuple)
else str(field_value).replace(":", "")
)
for field_name, field_value in self.dict_items()
]
)
def dict_items(self):
return asdict(self).items()
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
import logging import logging
import os import os
import subprocess import subprocess
from dataclasses import fields, replace from dataclasses import replace
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import List from typing import List
......
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
......
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
import functools import functools
import os import os
@functools.lru_cache(None) @functools.lru_cache(None)
def library_path(): def library_path():
return os.path.join(os.path.dirname(__file__), 'library') return os.path.join(os.path.dirname(__file__), "library")
...@@ -65,8 +65,9 @@ def parse_data_type(args): ...@@ -65,8 +65,9 @@ def parse_data_type(args):
if args.ck_profier_op == "grouped_conv_fwd": if args.ck_profier_op == "grouped_conv_fwd":
args.data_type = 3 args.data_type = 3
if args.data_type == "bfp16": if args.data_type == "bfp16":
if args.ck_profier_op == "grouped_conv_bwd_weight" or \ if args.ck_profier_op == "grouped_conv_bwd_weight":
args.ck_profier_op == "grouped_conv_bwd_data" or \ args.data_type = 5
if args.ck_profier_op == "grouped_conv_bwd_data" or \
args.ck_profier_op == "grouped_conv_fwd": args.ck_profier_op == "grouped_conv_fwd":
args.data_type = 2 args.data_type = 2
......
...@@ -210,3 +210,4 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL ...@@ -210,3 +210,4 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL
add_subdirectory(smfmac_op) add_subdirectory(smfmac_op)
endif() endif()
add_subdirectory(position_embedding) add_subdirectory(position_embedding)
add_subdirectory(scatter_gather)
add_subdirectory(image_to_column) add_subdirectory(image_to_column)
add_subdirectory(gemm)
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_ck_tile_gemm_mem_pipeline test_gemm_mem_pipeline.cpp)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_gemm_mem_pipeline_util.hpp"
using F16 = ck_tile::half_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Col, Row, F16, F16, F32, F16>,
std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Row, Row, F16, F16, F32, F16>,
std::tuple< Col, Col, Row, F16, F16, F32, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGemmMemPipeline, KernelTypes);
#include "test_gemm_mem_pipeline_ut_cases.inc"
#pragma once
TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 1024;
constexpr int K = 320;
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024;
constexpr int K = 320;
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 1024;
constexpr int K = 432;
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 1024;
constexpr int K = 512;
for(int M : Ms)
this->Run(M, N, K);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <sstream>
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
template <typename Tuple>
class TestCkTileGemmMemPipeline : public ::testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
// TODO: expose tile size through test t-param ?
struct gemm_basic_args
{
const void* p_a;
const void* p_b;
void* p_c;
ck_tile::index_t kbatch;
ck_tile::index_t M;
ck_tile::index_t N;
ck_tile::index_t K;
ck_tile::index_t stride_A;
ck_tile::index_t stride_B;
ck_tile::index_t stride_C;
};
void invoke_gemm(const gemm_basic_args& args, const ck_tile::stream_config& s)
{
// TODO: This should be parameterized in tests
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr bool kPadA = true;
constexpr bool kPadB = true;
constexpr bool kPadC = true;
constexpr int kBlockPerCu = 1;
// ===============================================
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;
using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
Traits,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b,
args.p_c,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0)
{
std::cout << "Lunching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl;
}
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
};
if(has_hot_loop)
{
// Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Two)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Two>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
{
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Three>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
{
if(tail_num == ck_tile::TailNumber::Four)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Four>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
{
if(tail_num == ck_tile::TailNumber::Five)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Five>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
{
if(tail_num == ck_tile::TailNumber::Six)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Six>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
{
if(tail_num == ck_tile::TailNumber::Seven)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Seven>{});
}
}
}
else
{
// Tail number always Full - #PrefetchStages
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else
{
std::ostringstream err;
err << "When there's no hot loop, this tail number \"" << tail_num
<< "\" is not supported! " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
}
public:
std::vector<int> k_batches_;
void SetUp() override { k_batches_ = {1}; }
void Run(const int M,
const int N,
const int K,
const int StrideA = 0,
const int StrideB = 0,
const int StrideC = 0)
{
for(auto kb : k_batches_)
{
RunSingle(M, N, K, StrideA, StrideB, StrideC, kb);
}
}
void RunSingle(const int M,
const int N,
const int K,
const int StrideA,
const int StrideB,
const int StrideC,
int kbatch = 1)
{
using namespace ck_tile::literals;
auto f_host_tensor_descriptor = [](std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
{
// give a chance if stride is zero, return a default packed stride
if constexpr(std::is_same_v<decltype(layout),
ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
std::size_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{});
std::size_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{});
std::size_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{});
ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k);
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n);
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
gemm_basic_args args;
args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
args.kbatch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;
invoke_gemm(args, ck_tile::stream_config{nullptr, false});
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_ref);
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
EXPECT_TRUE(pass);
}
};
add_test_executable(test_scatter_gather scatter_gather.cpp)
# target_compile_options(test_scatter_gather PRIVATE -v --save-temps -Wno-gnu-line-marker)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "ck_tile/core.hpp"
#ifndef TEST_SCATTER_GATHER_VERBOSE
#define TEST_SCATTER_GATHER_VERBOSE 1
#endif
#define HIP_CALL(call) \
do \
{ \
hipError_t err = call; \
if(err != hipSuccess) \
{ \
printf("[hiperror](%d) fail to call %s", static_cast<int>(err), #call); \
exit(0); \
} \
} while(0)
/*
TODO:
This is a simple design of scatter/gather through indexing transform, with limitations
We may design a scatter/gather adaptor layer directly inside tile window
*/
template <ck_tile::index_t ROW_TILE_SIZE = 8,
ck_tile::index_t COL_TILE_SIZE = 32 * 8,
ck_tile::index_t BLOCK_SIZE = 256,
ck_tile::index_t ALIGNMENT = 8,
typename INDEX_BUF_TYPE = ck_tile::index_t,
typename DATA_TYPE = ck_tile::fp16_t>
__global__ void row_scatter_gather(const INDEX_BUF_TYPE* src_row_idx_ptr,
const INDEX_BUF_TYPE* dst_row_idx_ptr,
const DATA_TYPE* src_ptr,
DATA_TYPE* dst_ptr,
ck_tile::index_t n_row_total,
ck_tile::index_t /*n_row_select*/,
ck_tile::index_t n_cols)
{
using namespace ck_tile;
// some constexpr vars
constexpr index_t vec = ALIGNMENT;
static_assert(COL_TILE_SIZE % vec == 0);
constexpr index_t col_lanes = COL_TILE_SIZE / vec;
constexpr index_t warp_size = ck_tile::get_warp_size();
static_assert(warp_size % col_lanes == 0);
constexpr index_t row_lanes = warp_size / col_lanes;
constexpr index_t num_warps = BLOCK_SIZE / warp_size;
static_assert(ROW_TILE_SIZE % (num_warps * row_lanes) == 0);
constexpr index_t row_repeat = ROW_TILE_SIZE / (num_warps * row_lanes);
static_assert(
row_repeat == 1,
"currently indexing not support(and would be not performant) if row_repeat has more");
// tile partitioner
index_t tile_col_idx = 0;
index_t tile_row_idx = blockIdx.x * ROW_TILE_SIZE;
// create our tild distribution, which tell us the location of different threads
constexpr auto src_dist = make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<row_repeat, num_warps, row_lanes>, sequence<col_lanes, vec>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
const auto coord = src_dist.calculate_index();
const auto row_coord = coord[number<0>{}] + tile_row_idx;
// load the current row index from the indexing buffer. we do not use ck_tile utility here
INDEX_BUF_TYPE src_row_id = src_row_idx_ptr[row_coord];
INDEX_BUF_TYPE dst_row_id = dst_row_idx_ptr[row_coord];
// printf("-- tid:%d, src_row_id:%d, dst_row_id:%d\n", static_cast<int>(threadIdx.x),
// static_cast<int>(src_row_id), static_cast<int>(dst_row_id));
const auto src_view =
make_naive_tensor_view<address_space_enum::global>(src_ptr,
make_tuple(n_row_total, n_cols),
make_tuple(n_cols, 1),
number<vec>{}, // alignement
number<1>{});
const auto src_gather_view = transform_tensor_view(
src_view,
make_tuple(make_indexing_transform(
n_row_total,
src_row_id), // here we replace row_idx which is loaded from another buffer
make_pass_through_transform(n_cols)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto src_tile = make_tile_window(src_gather_view,
make_tuple(number<ROW_TILE_SIZE>{}, number<COL_TILE_SIZE>{}),
{tile_row_idx, tile_col_idx},
src_dist);
const auto dst_view =
make_naive_tensor_view<address_space_enum::global>(dst_ptr,
make_tuple(n_row_total, n_cols),
make_tuple(n_cols, 1),
number<vec>{},
number<1>{});
const auto dst_scatter_view = transform_tensor_view(
dst_view,
make_tuple(make_indexing_transform(
n_row_total,
dst_row_id), // here we replace row_idx which is loaded from another buffer
make_pass_through_transform(n_cols)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto dst_tile = make_tile_window(dst_scatter_view,
make_tuple(number<ROW_TILE_SIZE>{}, number<COL_TILE_SIZE>{}),
{tile_row_idx, tile_col_idx},
src_dist /*reuse distribution*/);
// we finished descriptor construction and index calculation, now start load/store
for(auto i = 0; i < n_cols; i += COL_TILE_SIZE)
{
// note that scatter/gather are just the same API when doing load store as normal memory
// operation
auto data = load_tile(src_tile);
store_tile(dst_tile, data);
move_tile_window(src_tile, {number<0>{}, number<COL_TILE_SIZE>{}});
move_tile_window(dst_tile, {number<0>{}, number<COL_TILE_SIZE>{}});
}
}
union pixel
{
struct __attribute__((packed))
{
unsigned int r : 6;
unsigned int c : 10;
};
ushort data;
};
struct unique_linear_rand
{
unique_linear_rand(int capacity_) : capacity(capacity_) {}
std::unordered_set<int> set;
int gen()
{
if(static_cast<int>(set.size()) >= capacity)
{
printf("overflow, but will give you an number as well\n");
return std::rand() % capacity;
}
while(1)
{
int r = std::rand() % capacity;
if(set.count(r) == 1)
{
continue;
}
set.insert(r);
return r;
}
}
int capacity;
};
int main()
{
int row_total = 64;
int row_select = 8 * 2;
int col = 256 * 2;
using fp16_t = ck_tile::fp16_t;
constexpr int row_tile = 8;
constexpr int col_tile = 256;
fp16_t* src = reinterpret_cast<fp16_t*>(malloc(row_total * col * sizeof(fp16_t)));
for(int i_r = 0; i_r < row_total; i_r++)
{
for(int i_c = 0; i_c < col; i_c++)
{
int i = i_r * col + i_c;
pixel p;
p.r = i_r;
p.c = i_c;
ushort d = p.data;
src[i] = ck_tile::bit_cast<fp16_t>(d); // for simplicity, just cast
}
}
fp16_t* dst = reinterpret_cast<fp16_t*>(malloc(row_total * col * sizeof(fp16_t)));
int* src_idx = reinterpret_cast<int*>(malloc(row_select * sizeof(int)));
int* dst_idx = reinterpret_cast<int*>(malloc(row_select * sizeof(int)));
// std::srand(std::time(std::nullptr));
// std::srand(11935);
std::srand(std::time(nullptr));
auto src_gen = unique_linear_rand(row_total);
auto dst_gen = unique_linear_rand(row_total); // dst index must be unique. src is fine
for(int i_r = 0; i_r < row_select; i_r++)
{
src_idx[i_r] = src_gen.gen();
dst_idx[i_r] = dst_gen.gen();
}
void* dev_src;
void* dev_dst;
void* dev_src_idx;
void* dev_dst_idx;
HIP_CALL(hipMalloc(&dev_src, row_total * col * sizeof(fp16_t)));
HIP_CALL(hipMalloc(&dev_dst, row_total * col * sizeof(fp16_t)));
HIP_CALL(hipMalloc(&dev_src_idx, row_select * sizeof(int)));
HIP_CALL(hipMalloc(&dev_dst_idx, row_select * sizeof(int)));
HIP_CALL(hipMemcpy(dev_src, src, row_total * col * sizeof(fp16_t), hipMemcpyHostToDevice));
HIP_CALL(hipMemcpy(dev_src_idx, src_idx, row_select * sizeof(int), hipMemcpyHostToDevice));
HIP_CALL(hipMemcpy(dev_dst_idx, dst_idx, row_select * sizeof(int), hipMemcpyHostToDevice));
constexpr int bdim = 256;
int gdim = (row_select + row_tile - 1) / row_tile;
row_scatter_gather<row_tile, col_tile><<<gdim, bdim>>>(reinterpret_cast<int*>(dev_src_idx),
reinterpret_cast<int*>(dev_dst_idx),
reinterpret_cast<fp16_t*>(dev_src),
reinterpret_cast<fp16_t*>(dev_dst),
row_total,
row_select,
col);
HIP_CALL(hipMemcpy(dst, dev_dst, row_total * col * sizeof(fp16_t), hipMemcpyDeviceToHost));
#if TEST_SCATTER_GATHER_VERBOSE
printf("select row:");
for(int i_r = 0; i_r < row_select; i_r++)
{
printf("%d->%d->%d ", i_r, src_idx[i_r], dst_idx[i_r]);
}
printf("\n");
#endif
int err_cnt = 0;
for(int i_r = 0; i_r < row_select; i_r++)
{
for(int i_c = 0; i_c < col; i_c++)
{
int i = dst_idx[i_r] * col + i_c;
pixel p = ck_tile::bit_cast<pixel>(dst[i]);
bool is_ok = p.r == src_idx[i_r] && p.c == i_c;
if(!is_ok)
{
if(i_c == 0)
printf("(%d)pixel: %dx%d -> %d\n", i_r, p.r, p.c, dst_idx[i_r]);
err_cnt++;
}
}
}
#if TEST_SCATTER_GATHER_VERBOSE
printf("err:%d\n", err_cnt);
#endif
free(src);
free(dst);
free(src_idx);
free(dst_idx);
return err_cnt == 0 ? 0 : -1;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment