Commit 3fcad951 authored by ThomasNing
Browse files

finished on the receive part.

parent c0439e64
......@@ -11,7 +11,7 @@ struct transfer_receive_basic_args
{
const void* p_reduce;
std::array<const void*, MaxSendGPUNum> p_receive_list;
const void* p_output;
void* p_output;
ck_tile::index_t host_gpu;
ck_tile::index_t device_id;
ck_tile::index_t M;
......
......@@ -21,7 +21,7 @@ struct ReduceReceiveKernel
{
const void* reduce_ptr;
std::array<const void*, MaxSendGPUNum> receive_ptr_list;
const void* output_ptr;
void* output_ptr;
index_t M;
index_t N;
};
......@@ -29,7 +29,7 @@ struct ReduceReceiveKernel
CK_TILE_HOST static constexpr ReduceReceiveKargs
MakeKargs(const void* reduce_ptr,
std::array<const void*, MaxSendGPUNum> receive_ptr_list,
const void* output_ptr,
void* output_ptr,
index_t M,
index_t N)
{
......@@ -91,7 +91,7 @@ struct ReduceReceiveKernel
number<ReduceReceivePipeline::Block_N>{}),
{i_m, i_n});
const ODataType* output_start = static_cast<const ODataType*>(kargs.output_ptr);
ODataType* output_start = static_cast<ODataType*>(kargs.output_ptr);
auto output_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
output_start,
......@@ -106,10 +106,8 @@ struct ReduceReceiveKernel
number<ReduceReceivePipeline::Block_N>{}),
{i_m, i_n});
__shared__ char smem_ptr[ReduceReceivePipeline::GetSmemSize()];
ReduceReceivePipeline{}(
transfer_block_window, receive_block_window, output_block_window, smem_ptr);
transfer_block_window, receive_block_window, output_block_window);
return;
}
};
......
......@@ -46,16 +46,8 @@ struct CrossReduceReceivePipelineScaleUp
CK_TILE_HOST_DEVICE auto
operator()(const InDramBlockWindowTmp& input_dram_block_window_tmp,
const ReceiveDramBlockWindowTmp& receive_dram_block_window_tmp,
const OutDramBlockWindowTmp& output_dram_block_window_tmp,
void* p_smem) const
OutDramBlockWindowTmp& output_dram_block_window_tmp) const
{
DataType* p_lds = static_cast<DataType*>(p_smem);
constexpr auto lds_block_desc = Policy::template MakeLdsBlockDescriptor<ReduceShape>();
auto lds_block = make_tensor_view<address_space_enum::lds>(p_lds, lds_block_desc);
constexpr index_t lds_block_space_size_aligned =
integer_divide_ceil(sizeof(DataType) * lds_block_desc.get_element_space_size(), 16) *
16;
// DRAM tile window for load
auto copy_dram_window =
make_tile_window(input_dram_block_window_tmp.get_bottom_tensor_view(),
......@@ -63,40 +55,25 @@ struct CrossReduceReceivePipelineScaleUp
input_dram_block_window_tmp.get_window_origin(),
Policy::template MakeDramTileDistribution<ReduceShape>());
auto copy_lds_window = make_tile_window(lds_block,
make_tuple(number<Block_M>{}, number<Block_N>{}),
{0, 0},
copy_dram_window.get_tile_distribution());
auto host_block_tile = load_tile(copy_dram_window);
// Receive tile window initialization
DataType* p_receive_lds = static_cast<DataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + lds_block_space_size_aligned));
auto receive_dram_window =
make_tile_window(receive_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<Block_M>{}, number<Block_N>{}),
receive_dram_block_window_tmp.get_window_origin(),
Policy::template MakeDramTileDistribution<ReduceShape>());
auto receive_lds_block =
make_tensor_view<address_space_enum::lds>(p_receive_lds, lds_block_desc);
auto receive_lds_window = make_tile_window(receive_lds_block,
make_tuple(number<Block_M>{}, number<Block_N>{}),
{0, 0},
receive_dram_window.get_tile_distribution());
auto receive_block_tile = load_tile(receive_dram_window);
const auto host_block_tile_tmp =
tile_elementwise_in([](const DataType& a) { return a; }, host_block_tile);
store_tile(copy_lds_window, host_block_tile_tmp);
const auto receive_block_tile_tmp =
tile_elementwise_in([](const DataType& a) { return a; }, receive_block_tile);
store_tile(receive_lds_window, receive_block_tile_tmp);
auto acc = cast_tile<ODataType>(host_block_tile);
__syncthreads();
sweep_tile(receive_block_tile, [&](auto idx) {
acc(idx) =type_convert<DataType>(receive_block_tile(idx)) + acc(idx);
});
store_tile(const_cast<OutDramBlockWindowTmp&>(output_dram_block_window_tmp), acc);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment