Commit 565311a8 authored by ThomasNing

Finished the memory initialization, working on the merging of MSCCLPP

parent 5fb150db
find_library(MSCCLPP_LIBRARY mscclpp HINTS /mscclpp/build)
if(NOT MSCCLPP_LIBRARY)
message(FATAL_ERROR "MSCCLPP library not found in /mscclpp/build.")
endif()
find_path(MSCCLPP_INCLUDE_DIR mscclpp/core.hpp HINTS /mscclpp/include)
if(NOT MSCCLPP_INCLUDE_DIR)
message(FATAL_ERROR "MSCCLPP include directory not found in /mscclpp/include.")
endif()
add_executable(example_cross_gpu_reduce cross_gpu_reduce.cpp)
target_include_directories(example_cross_gpu_reduce PRIVATE ${MSCCLPP_INCLUDE_DIR})
target_link_libraries(example_cross_gpu_reduce PRIVATE ${MSCCLPP_LIBRARY})
# Cross GPU Reduce Communication
This folder contains an example in which different GPUs communicate with each other to complete a reduce. It is currently a test operator to verify and examine the communication between two GPUs.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make example_cross_gpu_reduce -j
```
This will result in an executable `build/bin/example_cross_gpu_reduce`
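
After building, the executable can be run directly. A hypothetical invocation (argument names and defaults are taken from `create_args` in this example; the `-name=value` syntax assumes ck_tile's `ArgParser` conventions):
```
# run a 2-GPU fp16 reduce with GPU 0 acting as the host GPU
./bin/example_cross_gpu_reduce -gpu_nums=2 -host_gpu=0 -pr=fp16 -M=1024 -N=2 -warmup=50 -repeat=100
```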
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <string>
#include <thread>
#include <future>
#include <vector>
#include <mscclpp/core.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/sm_channel.hpp>
#include "cross_gpu_reduce.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/cross_gpu_reduce.hpp"
template <class T>
using DeviceHandle = mscclpp::DeviceHandle<T>;
__constant__ DeviceHandle<mscclpp::SmChannel> constSmChannels[8]; // For SmChannel
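// NOTE: setupConnection below sketches the intended MSCCLPP bootstrap / memory-registration
// flow. It is not called from main() yet (the MSCCLPP merge is still in progress), and the
// transport/bootstrap types used here must match the installed MSCCLPP version.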
void setupConnection(int rank, int worldSize, void* data, size_t dataSize)
{
    // Initialize the MSCCL++ transport type.
    mscclpp::Transport transport = mscclpp::Transport::SmChannel;
    // Create the communicator on top of a bootstrap object.
    auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, worldSize);
    mscclpp::Communicator comm(bootstrap);
    // Register the local buffer so that peers can access it (to be exchanged with the peers
    // once the MSCCLPP merge is complete).
    auto localMemory = comm.registerMemory(data, dataSize, transport);
    std::vector<mscclpp::RegisteredMemory> remoteMemories;
    // Pre-size the connection list so it can be indexed by peer rank.
    std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connections(worldSize);
    if(rank == 0)
    {
        for(int senderRank = 1; senderRank < worldSize; ++senderRank)
        {
            connections[senderRank] = comm.connectOnSetup(senderRank, 0, transport);
            // Receive the registered memory handle from each sender.
            remoteMemories.push_back(comm.recvMemoryOnSetup(senderRank, 0));
        }
    }
    else
    {
        connections[0] = comm.connectOnSetup(0, 0, transport);
    }
}
template <typename InputType, typename OutputType>
struct AllocateAndTransferFunctor
{
// Invoke the memory transfer between GPUs, depending on whether this device is the host GPU or a slave GPU.
float invoke_transfer(ck_tile::DeviceMem& transfer_buf,
ck_tile::index_t host_gpu,
int device_id,
const ck_tile::ArgParser& arg_parser,
const ck_tile::stream_config& s,
std::promise<const void*>& host_receive_ptr_promise,
std::future<const void*>& host_receive_ptr_future)
{
ck_tile::index_t M = arg_parser.get_int("M");
ck_tile::index_t N = arg_parser.get_int("N");
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t M_Warp_Tile = 64;
constexpr ck_tile::index_t N_Warp_Tile = 64;
constexpr int kBlockPerCu = 1;
using Vector = ck_tile::sequence<8, 8>;
using ReduceShape = ck_tile::TileReduceShape<ck_tile::sequence<M_Tile, N_Tile>,
ck_tile::sequence<M_Warp, N_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile>,
Vector>;
using ReducePartitioner = ck_tile::CrossReducePartitioner<ReduceShape>;
using CrossReduceReceivePipelinePolicy = ck_tile::ReduceReceivePipelineDefaultPolicy;
using CrossReduceSendPipelinePolicy = ck_tile::ReduceSendPipelineDefaultPolicy;
using CrossReduceReceivePipeline =
ck_tile::CrossReduceReceivePipelineScaleUp<InputType,
OutputType,
ReduceShape,
CrossReduceReceivePipelinePolicy>;
using CrossReduceSendPipeline = ck_tile::
CrossReduceSendPipelineScaleUp<InputType, ReduceShape, CrossReduceSendPipelinePolicy>;
constexpr ck_tile::index_t kBlockSize = CrossReduceReceivePipeline::BlockSize;
transfer_receive_basic_args args_receive;
args_receive.p_reduce = transfer_buf.GetDeviceBuffer();
args_receive.host_gpu = host_gpu;
args_receive.device_id = static_cast<ck_tile::index_t>(device_id);
args_receive.M = M;
args_receive.N = N;
transfer_send_basic_args args_send;
args_send.p_reduce = transfer_buf.GetDeviceBuffer();
args_send.host_gpu = host_gpu;
args_send.device_id = static_cast<ck_tile::index_t>(device_id);
args_send.M = M;
args_send.N = N;
float ave_time = 0.0;
// using MasterKernel = ck_tile::ReduceSendKernel<CrossReduceSendPipeline>;
using SlaveKernel =
ck_tile::ReduceReceiveKernel<ReducePartitioner, CrossReduceReceivePipeline>;
using MasterKernel = ck_tile::ReduceSendKernel<ReducePartitioner, CrossReduceSendPipeline>;
// Launch either the receiving kernel (on the host GPU) or the sending kernel (on the other GPUs).
if(static_cast<ck_tile::index_t>(device_id) == host_gpu)
{
// initialize the receive data buffer and global memory location.
ck_tile::HostTensor<InputType> receive_host({M, N});
ck_tile::DeviceMem receive_buf(receive_host.get_element_space_size_in_bytes());
args_receive.p_receive = receive_buf.GetDeviceBuffer();
// initialize the output data buffer.
std::string output_type = arg_parser.get_str("output_type");
if(output_type.compare("float") == 0)
{
ck_tile::HostTensor<OutputType> output_host({M, N});
ck_tile::DeviceMem output_buf(output_host.get_element_space_size_in_bytes());
args_receive.p_output = output_buf.GetDeviceBuffer();
host_receive_ptr_promise.set_value(args_receive.p_receive);
auto kargs_slave = SlaveKernel::MakeKargs(args_receive.p_reduce,
args_receive.p_receive,
args_receive.p_output,
args_receive.M,
args_receive.N);
const dim3 grids_slave = SlaveKernel::GridSize(M, N);
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
SlaveKernel{}, grids_slave, kBlockSize, 0, kargs_slave));
}
else
{
std::cerr << "Currently, we do not support other output data type." << std::endl;
return -1;
}
}
else
{
const void* send_location_ptr = host_receive_ptr_future.get();
args_send.p_send = send_location_ptr;
auto kargs_master = MasterKernel::MakeKargs(
args_send.p_reduce, args_send.p_send, args_send.M, args_send.N);
const dim3 grids_master = MasterKernel::GridSize(M, N);
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
MasterKernel{}, grids_master, kBlockSize, 0, kargs_master));
}
std::string op_name{"Cross GPU Reduce"};
std::cout << "Run " << op_name << " kernel with M = " << M << ", N = " << N << " : " << ave_time
          << " ms" << std::endl;
return ave_time;
}
void operator()(int device_id,
ck_tile::HostTensor<InputType>& host_tensor,
ck_tile::DeviceMem& device_mem,
ck_tile::index_t host_gpu,
const ck_tile::ArgParser& arg_parser,
std::promise<const void*>& host_receive_ptr_promise,
std::future<const void*>& host_receive_ptr_future)
{
hipError_t hip_err_set_device = hipSetDevice(device_id);
if(hip_err_set_device != hipSuccess)
{
std::cerr << "Error setting device " << device_id << ": "
<< hipGetErrorString(hip_err_set_device) << std::endl;
return;
}
// Allocate device memory
device_mem.Realloc(host_tensor.get_element_space_size_in_bytes());
// Transfer data to device
device_mem.ToDevice(host_tensor.data());
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
invoke_transfer(device_mem,
host_gpu,
device_id,
arg_parser,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
host_receive_ptr_promise,
host_receive_ptr_future);
}
};
template <typename InputType, typename OutputType>
bool run_cross_gpu_reduce(ck_tile::ArgParser arg_parser)
{
ck_tile::index_t gpu_nums = arg_parser.get_int("gpu_nums");
ck_tile::index_t host_gpu = arg_parser.get_int("host_gpu");
ck_tile::index_t transfer_dim1 = arg_parser.get_int("M");
ck_tile::index_t transfer_dim2 = arg_parser.get_int("N");
// Validate arguments
if(gpu_nums < 1)
{
std::cerr << "Invalid number of GPUs specified." << std::endl;
return false;
}
// Query how many GPUs are present in the system.
int device_count = 0;
hipError_t hip_err_device_count = hipGetDeviceCount(&device_count);
if(hip_err_device_count != hipSuccess)
{
std::cerr << "Error getting device count: " << hipGetErrorString(hip_err_device_count)
<< std::endl;
return false;
}
// Make sure the number of available GPUs is at least the required gpu_nums.
if(device_count < gpu_nums)
{
std::cerr << "The number of available GPUs in the system is less than required. Available GPUs: "
<< device_count << std::endl;
return false;
}
if(host_gpu < 0 || host_gpu >= device_count)
{
std::cerr << "Invalid host GPU index specified. Using GPU 0 as host GPU." << std::endl;
host_gpu = 0;
}
// Make sure that we can open each one of the GPUs.
// Print device properties
for(int i = 0; i < gpu_nums; ++i)
{
hipDeviceProp_t device_prop;
hipError_t hip_err_device_prop = hipGetDeviceProperties(&device_prop, i);
if(hip_err_device_prop != hipSuccess)
{
std::cerr << "Error getting device properties for device " << i << ": "
<< hipGetErrorString(hip_err_device_prop) << std::endl;
return false;
}
std::cout << "GPU " << i << ": " << device_prop.name << std::endl;
}
std::vector<int> device_list(gpu_nums);
std::vector<ck_tile::HostTensor<InputType>> transfer_tensor_host_list;
transfer_tensor_host_list.reserve(gpu_nums);
std::vector<ck_tile::DeviceMem> transfer_bufs(gpu_nums);
std::vector<std::thread> threads;
AllocateAndTransferFunctor<InputType, OutputType> allocateAndTransfer;
// Initialize host tensors
for(int i = 0; i < gpu_nums; ++i)
{
device_list[i] = i; // Adjust based on available GPUs
std::vector<int> tensor_dims = {transfer_dim1, transfer_dim2};
transfer_tensor_host_list.emplace_back(tensor_dims);
ck_tile::FillUniformDistribution<InputType>{-5.f, 5.f}(transfer_tensor_host_list.back());
// Enable P2P access between GPUs
if(i != host_gpu)
{
int canAccessPeer = 0;
hipError_t err_peer =
hipDeviceCanAccessPeer(&canAccessPeer, device_list[host_gpu], device_list[i]);
if(err_peer != hipSuccess || !canAccessPeer)
{
std::cerr << "P2P not supported between device " << device_list[host_gpu]
<< " and device " << device_list[i] << std::endl;
return false; // Handle error accordingly.
}
else
{
// Enable P2P access from host GPU to device i.
hipError_t hip_err_set_device_host = hipSetDevice(device_list[host_gpu]);
if(hip_err_set_device_host != hipSuccess)
{
std::cerr << "Error setting the host device " << host_gpu << ": "
<< hipGetErrorString(hip_err_set_device_host) << std::endl;
return false;
}
hipError_t err_peer_host = hipDeviceEnablePeerAccess(device_list[i], 0);
if(err_peer_host != hipSuccess && err_peer_host != hipErrorPeerAccessAlreadyEnabled)
{
std::cerr << "Error enabling peer access from host " << device_list[host_gpu]
<< " to device " << device_list[i] << ": "
<< hipGetErrorString(err_peer_host) << std::endl;
return false;
}
// Enable P2P access from device i to host GPU.
hipError_t hip_err_set_device_send = hipSetDevice(device_list[i]);
if(hip_err_set_device_send != hipSuccess)
{
std::cerr << "Error setting the host device " << host_gpu << ": "
<< hipGetErrorString(hip_err_set_device_send) << std::endl;
return -1;
}
hipError_t err_peer_device = hipDeviceEnablePeerAccess(device_list[host_gpu], 0);
if(err_peer_device != hipSuccess &&
err_peer_device != hipErrorPeerAccessAlreadyEnabled)
{
std::cerr << "Error enabling peer access from device " << device_list[i]
<< " to device " << device_list[host_gpu] << ": "
<< hipGetErrorString(err_peer_device) << std::endl;
return false;
}
}
}
}
for(int i = 0; i < gpu_nums; ++i)
{
hipError_t hip_device_sync_enable = hipSetDevice(device_list[i]);
if(hip_device_sync_enable != hipSuccess)
{
std::cerr << "Error enable the device for synchronization" << std::endl;
return -1;
}
hipError_t hip_device_sync_err = hipDeviceSynchronize();
if(hip_device_sync_err != hipSuccess)
{
std::cerr << "Error in complete the device for synchronization" << std::endl;
return -1;
}
}
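// A single promise/future pair publishes the host GPU's receive-buffer pointer to the sender
// threads. Note that std::future::get() may only be called once, so this currently assumes a
// single sending GPU (gpu_nums == 2).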
std::promise<const void*> host_receive_ptr_promise;
std::future<const void*> host_receive_ptr_future = host_receive_ptr_promise.get_future();
for(int i = 0; i < gpu_nums; ++i)
{
threads.emplace_back(allocateAndTransfer,
device_list[i],
std::ref(transfer_tensor_host_list[i]),
std::ref(transfer_bufs[i]),
host_gpu,
arg_parser,
std::ref(host_receive_ptr_promise),
std::ref(host_receive_ptr_future));
}
// Wait for all threads to complete
for(auto& t : threads)
{
t.join();
}
// TODO: CPU validation (the "v" option) is not wired up yet; report success once all threads
// have completed the transfer.
bool pass = true;
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string prec = arg_parser.get_str("pr");
bool run_result = true;
if(prec.compare("fp16") == 0)
{
run_result &= run_cross_gpu_reduce<ck_tile::fp16_t, float>(arg_parser);
}
else
{
std::cerr << "Unsupported input precision: " << prec << std::endl;
run_result = false;
}
return run_result ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/host.hpp"
struct transfer_receive_basic_args
{
const void* p_reduce;
const void* p_receive;
const void* p_output;
ck_tile::index_t host_gpu;
ck_tile::index_t device_id;
ck_tile::index_t M;
ck_tile::index_t N;
};
struct transfer_send_basic_args
{
const void* p_reduce;
const void* p_send;
ck_tile::index_t host_gpu;
ck_tile::index_t device_id;
ck_tile::index_t M;
ck_tile::index_t N;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("gpu_nums", "2", "number of gpu")
.insert("transfer_size", "2048", "transfer memory size")
.insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
.insert(
"output_type", "float", "output data type. fp16/fp32 (representing 8/16/32 bit data)")
.insert("M", "1024", "transfer memory first dimension")
.insert("N", "2", "transfer memory second dimension")
.insert("op_type", "reduce_add", "Operation type between different GPUs")
.insert("host_gpu", "0", "host gpu #")
.insert("v", "1", "cpu validation or not")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
@@ -13,3 +13,4 @@ add_subdirectory(10_rmsnorm2d)
add_subdirectory(11_add_rmsnorm2d_rdquant)
add_subdirectory(12_smoothquant)
add_subdirectory(13_moe_sorting)
add_subdirectory(15_cross_gpu_reduce)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_receive_kernel.hpp"
#include "ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_send_kernel.hpp"
#include "ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_reduce_shape.hpp"
#include "ck_tile/ops/cross_gpu_reduce/kernel/cross_gpu_reduce_tile_partitioner.hpp"
#include "ck_tile/ops/cross_gpu_reduce/pipeline/reduce_receive_pipeline_scale_up.hpp"
#include "ck_tile/ops/cross_gpu_reduce/pipeline/reduce_receive_pipeline_default_policy.hpp"
#include "ck_tile/ops/cross_gpu_reduce/pipeline/reduce_send_pipeline_scale_up.hpp"
#include "ck_tile/ops/cross_gpu_reduce/pipeline/reduce_send_pipeline_default_policy.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
template <typename CrossReducePartitioner, typename ReduceReceivePipeline_>
struct ReduceReceiveKernel
{
using ReduceReceivePipeline = remove_cvref_t<ReduceReceivePipeline_>;
static constexpr index_t TransferBlockSize = ReduceReceivePipeline::BlockSize;
using DataType = remove_cvref_t<typename ReduceReceivePipeline::DataType>;
using ODataType = remove_cvref_t<typename ReduceReceivePipeline::ODataType>;
struct ReduceReceiveKargs
{
const void* reduce_ptr;
const void* receive_ptr;
const void* output_ptr;
index_t M;
index_t N;
};
CK_TILE_HOST static constexpr ReduceReceiveKargs MakeKargs(const void* reduce_ptr,
const void* receive_ptr,
const void* output_ptr,
index_t M,
index_t N)
{
return ReduceReceiveKargs{reduce_ptr, receive_ptr, output_ptr, M, N};
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return ReduceReceivePipeline::GetSmemSize();
}
__host__ static constexpr auto GridSize(index_t M_size, index_t N_size)
{
return CrossReducePartitioner::GridSize(M_size, N_size);
}
CK_TILE_DEVICE void operator()(ReduceReceiveKargs kargs) const
{
const auto i_M = CrossReducePartitioner{}();
const DataType* reduce_start = static_cast<const DataType*>(kargs.reduce_ptr);
auto transfer_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
reduce_start,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.N, 1),
number<ReduceReceivePipeline::Vector_N>{},
number<1>{});
}();
auto transfer_block_window =
make_tile_window(transfer_tensor_view,
make_tuple(number<ReduceReceivePipeline::Block_M>{},
number<ReduceReceivePipeline::Block_N>{}),
{i_M, 0});
const ODataType* output_start = static_cast<const ODataType*>(kargs.output_ptr);
auto output_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
output_start,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.N, 1),
number<ReduceReceivePipeline::Vector_N>{},
number<1>{});
}();
auto output_block_window =
make_tile_window(output_tensor_view,
make_tuple(number<ReduceReceivePipeline::Block_M>{},
number<ReduceReceivePipeline::Block_N>{}),
{i_M, 0});
__shared__ char smem_ptr[ReduceReceivePipeline::GetSmemSize()];
ReduceReceivePipeline{}(transfer_block_window, output_block_window, smem_ptr);
return;
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename BlockTile_, typename BlockWarps_, typename WarpTile_, typename Vector>
struct TileReduceShape
{
using BlockTile = remove_cvref_t<BlockTile_>;
using BlockWarps = remove_cvref_t<BlockWarps_>;
using WarpTile = remove_cvref_t<WarpTile_>;
static constexpr index_t Block_M = BlockTile::at(number<0>{});
static constexpr index_t Block_N = BlockTile::at(number<1>{});
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
static constexpr index_t Warp_M = WarpTile::at(number<0>{});
static constexpr index_t Warp_N = WarpTile::at(number<1>{});
static constexpr index_t Vector_M = Vector::at(number<0>{});
static constexpr index_t Vector_N = Vector::at(number<1>{});
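// Derived tiling: rows/columns handled by each warp, the per-thread tile footprint implied by
// the vector lengths, and the resulting thread layout inside a warp.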
static constexpr index_t MPerWarp = Block_M / WarpPerBlock_M;
static constexpr index_t NPerWarp = Block_N / WarpPerBlock_N;
static constexpr index_t ThreadTile_M = MPerWarp / Vector_M;
static constexpr index_t ThreadTile_N = NPerWarp / Vector_N;
static constexpr index_t MThreadPerWarp = MPerWarp / ThreadTile_M;
static constexpr index_t NThreadPerWarp = NPerWarp / ThreadTile_N;
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename CrossReduceShape_>
struct CrossReducePartitioner
{
using CrossReduceShape = remove_cvref_t<CrossReduceShape_>;
static constexpr index_t kM = CrossReduceShape::Block_M;
static constexpr index_t kN = CrossReduceShape::Block_N;
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N)
{
index_t GridDimX = (M + kM - 1) / kM;
index_t GridDimY = (N + kN - 1) / kN;
return dim3(GridDimX, GridDimY, 1);
}
CK_TILE_DEVICE auto operator()() {
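// Each block owns a Block_M-row stripe of the tensor; readfirstlane keeps the row offset in a
// scalar register so it is uniform across the wavefront.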
const index_t i_M = __builtin_amdgcn_readfirstlane(blockIdx.x * kM);
return i_M;
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
template <typename CrossReducePartitioner, typename ReduceSendPipeline_>
struct ReduceSendKernel
{
using ReduceSendPipeline = remove_cvref_t<ReduceSendPipeline_>;
using DataType = remove_cvref_t<typename ReduceSendPipeline::DataType>;
struct ReduceSendKargs
{
const void* reduce_ptr;
const void* send_ptr;
index_t M;
index_t N;
};
CK_TILE_HOST static constexpr ReduceSendKargs
MakeKargs(const void* reduce_ptr, const void* send_ptr, index_t M, index_t N)
{
return ReduceSendKargs{reduce_ptr, send_ptr, M, N};
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return ReduceSendPipeline::GetSmemSize();
}
__host__ static constexpr auto GridSize(index_t M_size, index_t N_size)
{
return CrossReducePartitioner::GridSize(M_size, N_size);
}
CK_TILE_DEVICE void operator()(ReduceSendKargs kargs) const
{
const auto i_M = CrossReducePartitioner{}();
const DataType* reduce_start = static_cast<const DataType*>(kargs.reduce_ptr);
auto transfer_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
reduce_start,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.N, 1),
number<ReduceSendPipeline::Vector_N>{},
number<1>{});
}();
auto transfer_block_window =
make_tile_window(transfer_tensor_view,
make_tuple(number<ReduceSendPipeline::Block_M>{},
number<ReduceSendPipeline::Block_N>{}),
{i_M, 0});
__shared__ char smem_ptr[ReduceSendPipeline::GetSmemSize()];
ReduceSendPipeline{}(transfer_block_window, kargs.send_ptr, smem_ptr);
return;
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
struct ReduceReceivePipelineDefaultPolicy
{
template <typename ReduceShape>
CK_TILE_DEVICE static constexpr auto MakeDramTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<ReduceShape::WarpPerBlock_M,
ReduceShape::MThreadPerWarp,
ReduceShape::ThreadTile_M>,
sequence<ReduceShape::WarpPerBlock_N,
ReduceShape::NThreadPerWarp,
ReduceShape::ThreadTile_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>{});
}
template <typename ReduceShape>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kMPerBlock = ReduceShape::Block_M;
constexpr index_t kNPerBlock = ReduceShape::Block_N;
constexpr auto lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kNPerBlock), number<32>{});
return lds_block_desc;
}
template <typename DataType, typename ReduceShape>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
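// Twice the LDS tile size: one buffer for the local tile and one for the tile received from
// the peer GPU (see p_receive_lds in CrossReduceReceivePipelineScaleUp).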
constexpr index_t smem_size_host =
sizeof(DataType) *
MakeLdsBlockDescriptor<ReduceShape>().get_element_space_size();
return smem_size_host * 2;
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/cross_gpu_reduce/pipeline/reduce_receive_pipeline_default_policy.hpp"
namespace ck_tile {
template <typename DataType_,
typename ODataType_,
typename ReduceShape_,
typename Policy = ReduceReceivePipelineDefaultPolicy>
struct CrossReduceReceivePipelineScaleUp
{
using DataType = remove_cvref_t<DataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using ReduceShape = remove_cvref_t<ReduceShape_>;
static constexpr index_t Block_M = ReduceShape::Block_M;
static constexpr index_t Block_N = ReduceShape::Block_N;
static constexpr index_t Vector_M = ReduceShape::Vector_M;
static constexpr index_t Vector_N = ReduceShape::Vector_N;
static constexpr index_t BlockSize = ReduceShape::NumWarps * get_warp_size();
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{
return integer_divide_ceil(2 * sizeof(DataType) *
Policy::template MakeLdsBlockDescriptor<ReduceShape>()
.get_element_space_size(),
16) *
16;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<DataType, ReduceShape>();
}
template <typename InDramBlockWindowTmp, typename OutDramBlockWindowTmp>
CK_TILE_HOST_DEVICE auto operator()(const InDramBlockWindowTmp& input_dram_block_window_tmp,
const OutDramBlockWindowTmp& output_dram_block_window_tmp,
void* p_smem) const
{
DataType* p_lds = static_cast<DataType*>(p_smem);
constexpr auto lds_block_desc = Policy::template MakeLdsBlockDescriptor<ReduceShape>();
auto lds_block = make_tensor_view<address_space_enum::lds>(p_lds, lds_block_desc);
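// Carve out a second, 16-byte-aligned LDS region for the tile received from the peer GPU.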
constexpr index_t lds_block_space_size_aligned =
integer_divide_ceil(sizeof(DataType) * lds_block_desc.get_element_space_size(), 16) *
16;
DataType* p_receive_lds = static_cast<DataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + lds_block_space_size_aligned));
// DRAM tile window for load
auto copy_dram_window =
make_tile_window(input_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<Block_M>{}, number<Block_N>{}),
input_dram_block_window_tmp.get_window_origin(),
Policy::template MakeDramTileDistribution<ReduceShape>());
auto copy_lds_window = make_tile_window(lds_block,
make_tuple(number<Block_M>{}, number<Block_N>{}),
{0, 0},
copy_dram_window.get_tile_distribution());
auto host_block_tile = load_tile(copy_dram_window);
const auto block_tile_tmp =
tile_elementwise_in([](const DataType& a) { return a; }, host_block_tile);
store_tile(copy_lds_window, block_tile_tmp);
move_tile_window(copy_lds_window, {0, Block_N});
__syncthreads();
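// TODO: the reduction of the received tile and the write-back through
// output_dram_block_window_tmp are not implemented yet; this pipeline currently only stages
// the local tile into LDS.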
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
struct ReduceSendPipelineDefaultPolicy
{
template <typename ReduceShape>
CK_TILE_DEVICE static constexpr auto MakeDramTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<ReduceShape::WarpPerBlock_M,
ReduceShape::MThreadPerWarp,
ReduceShape::ThreadTile_M>,
sequence<ReduceShape::WarpPerBlock_N,
ReduceShape::NThreadPerWarp,
ReduceShape::ThreadTile_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>{});
}
template <typename ReduceShape>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kMPerBlock = ReduceShape::Block_M;
constexpr index_t kNPerBlock = ReduceShape::Block_N;
constexpr auto lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kNPerBlock), number<32>{});
return lds_block_desc;
}
template <typename DataType, typename ReduceShape>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_host =
sizeof(DataType) * MakeLdsBlockDescriptor<ReduceShape>().get_element_space_size();
return smem_size_host;
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/cross_gpu_reduce/pipeline/reduce_send_pipeline_default_policy.hpp"
namespace ck_tile {
template <typename DataType_,
typename ReduceShape_,
typename Policy = ReduceSendPipelineDefaultPolicy>
struct CrossReduceSendPipelineScaleUp
{
using DataType = remove_cvref_t<DataType_>;
using ReduceShape = remove_cvref_t<ReduceShape_>;
static constexpr index_t Block_M = ReduceShape::Block_M;
static constexpr index_t Block_N = ReduceShape::Block_N;
static constexpr index_t Vector_N = ReduceShape::Vector_N;
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{
return integer_divide_ceil(sizeof(DataType) *
Policy::template MakeLdsBlockDescriptor<ReduceShape>()
.get_element_space_size(),
16) *
16;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<DataType, ReduceShape>();
}
template <typename InDramBlockWindowTmp>
CK_TILE_HOST_DEVICE auto operator()(const InDramBlockWindowTmp& input_dram_block_window_tmp,
const void* p_send,
void* p_smem) const
{
DataType* p_lds = static_cast<DataType*>(p_smem);
constexpr auto lds_block_desc = Policy::template MakeLdsBlockDescriptor<ReduceShape>();
auto lds_block = make_tensor_view<address_space_enum::lds>(p_lds, lds_block_desc);
// DRAM tile window for load
auto copy_dram_window =
make_tile_window(input_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<Block_M>{}, number<Block_N>{}),
input_dram_block_window_tmp.get_window_origin(),
Policy::template MakeDramTileDistribution<ReduceShape>());
auto copy_lds_window = make_tile_window(lds_block,
make_tuple(number<Block_M>{}, number<Block_N>{}),
{0, 0},
copy_dram_window.get_tile_distribution());
auto host_block_tile = load_tile(copy_dram_window);
const auto block_tile_tmp =
tile_elementwise_in([](const DataType& a) { return a; }, host_block_tile);
store_tile(copy_lds_window, block_tile_tmp);
move_tile_window(copy_lds_window, {0, Block_N});
__syncthreads();
// send the data.
}
};
} // namespace ck_tile