Commit 3fcad951 authored by ThomasNing

Finished the receive part.

parent c0439e64
@@ -11,7 +11,7 @@ struct transfer_receive_basic_args
 {
     const void* p_reduce;
     std::array<const void*, MaxSendGPUNum> p_receive_list;
-    const void* p_output;
+    void* p_output;
     ck_tile::index_t host_gpu;
     ck_tile::index_t device_id;
     ck_tile::index_t M;
...
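The const removal on p_output is the point of this hunk: the output buffer is the one pointer the kernel writes through, and a const void* cannot be cast to a writable pointer without a const_cast. A minimal sketch (plain C++, with a hypothetical args_sketch struct rather than the real ck_tile types) of the distinction:

#include <cstdio>

struct args_sketch
{
    const void* p_reduce; // read-only input: const is correct here
    void* p_output;       // written by the kernel: must stay non-const
};

int main()
{
    float in[4]  = {1.f, 2.f, 3.f, 4.f};
    float out[4] = {};
    args_sketch a{in, out};

    const float* src = static_cast<const float*>(a.p_reduce);
    float* dst       = static_cast<float*>(a.p_output);
    // With the old `const void* p_output`, the cast above would not
    // compile without a const_cast.
    for(int i = 0; i < 4; ++i)
        dst[i] = src[i];
    std::printf("out[0] = %f\n", dst[0]);
    return 0;
}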
@@ -21,7 +21,7 @@ struct ReduceReceiveKernel
 {
     const void* reduce_ptr;
     std::array<const void*, MaxSendGPUNum> receive_ptr_list;
-    const void* output_ptr;
+    void* output_ptr;
     index_t M;
     index_t N;
 };
@@ -29,7 +29,7 @@ struct ReduceReceiveKernel
     CK_TILE_HOST static constexpr ReduceReceiveKargs
     MakeKargs(const void* reduce_ptr,
               std::array<const void*, MaxSendGPUNum> receive_ptr_list,
-              const void* output_ptr,
+              void* output_ptr,
               index_t M,
               index_t N)
     {
@@ -91,7 +91,7 @@ struct ReduceReceiveKernel
                     number<ReduceReceivePipeline::Block_N>{}),
                 {i_m, i_n});
-        const ODataType* output_start = static_cast<const ODataType*>(kargs.output_ptr);
+        ODataType* output_start = static_cast<ODataType*>(kargs.output_ptr);
         auto output_tensor_view = [&]() {
             return make_naive_tensor_view<address_space_enum::global>(
                 output_start,
@@ -106,10 +106,8 @@ struct ReduceReceiveKernel
                     number<ReduceReceivePipeline::Block_N>{}),
                 {i_m, i_n});
-        __shared__ char smem_ptr[ReduceReceivePipeline::GetSmemSize()];
         ReduceReceivePipeline{}(
-            transfer_block_window, receive_block_window, output_block_window, smem_ptr);
+            transfer_block_window, receive_block_window, output_block_window);
         return;
     }
 };
...
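Because the pipeline no longer stages tiles through LDS (see the pipeline diff below), the kernel drops both the __shared__ buffer and the smem_ptr argument. A hedged HIP/CUDA-style sketch (hypothetical reduce_receive_sketch kernel, not the real ck_tile code) of the structural difference:

__global__ void reduce_receive_sketch(const float* local_buf,
                                      const float* peer_buf,
                                      float* out,
                                      int n)
{
    // Old flow: copy the local and peer tiles into __shared__ staging
    // buffers, __syncthreads(), then reduce out of LDS.
    // New flow: load both operands straight into registers and add;
    // no shared-memory allocation is needed at all.
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        out[i] = local_buf[i] + peer_buf[i];
}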
@@ -46,16 +46,8 @@ struct CrossReduceReceivePipelineScaleUp
     CK_TILE_HOST_DEVICE auto
     operator()(const InDramBlockWindowTmp& input_dram_block_window_tmp,
                const ReceiveDramBlockWindowTmp& receive_dram_block_window_tmp,
-               const OutDramBlockWindowTmp& output_dram_block_window_tmp,
-               void* p_smem) const
+               OutDramBlockWindowTmp& output_dram_block_window_tmp) const
     {
-        DataType* p_lds = static_cast<DataType*>(p_smem);
-        constexpr auto lds_block_desc = Policy::template MakeLdsBlockDescriptor<ReduceShape>();
-        auto lds_block = make_tensor_view<address_space_enum::lds>(p_lds, lds_block_desc);
-        constexpr index_t lds_block_space_size_aligned =
-            integer_divide_ceil(sizeof(DataType) * lds_block_desc.get_element_space_size(), 16) *
-            16;
         // DRAM tile window for load
         auto copy_dram_window =
             make_tile_window(input_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -63,40 +55,25 @@ struct CrossReduceReceivePipelineScaleUp
                              input_dram_block_window_tmp.get_window_origin(),
                              Policy::template MakeDramTileDistribution<ReduceShape>());
-        auto copy_lds_window = make_tile_window(lds_block,
-                                                make_tuple(number<Block_M>{}, number<Block_N>{}),
-                                                {0, 0},
-                                                copy_dram_window.get_tile_distribution());
         auto host_block_tile = load_tile(copy_dram_window);
-        // Receive tile window initialization
-        DataType* p_receive_lds = static_cast<DataType*>(
-            static_cast<void*>(static_cast<char*>(p_smem) + lds_block_space_size_aligned));
         auto receive_dram_window =
             make_tile_window(receive_dram_block_window_tmp.get_bottom_tensor_view(),
                              make_tuple(number<Block_M>{}, number<Block_N>{}),
                              receive_dram_block_window_tmp.get_window_origin(),
                              Policy::template MakeDramTileDistribution<ReduceShape>());
-        auto receive_lds_block =
-            make_tensor_view<address_space_enum::lds>(p_receive_lds, lds_block_desc);
-        auto receive_lds_window = make_tile_window(receive_lds_block,
-                                                   make_tuple(number<Block_M>{}, number<Block_N>{}),
-                                                   {0, 0},
-                                                   receive_dram_window.get_tile_distribution());
         auto receive_block_tile = load_tile(receive_dram_window);
-        const auto host_block_tile_tmp =
-            tile_elementwise_in([](const DataType& a) { return a; }, host_block_tile);
-        store_tile(copy_lds_window, host_block_tile_tmp);
-        const auto receive_block_tile_tmp =
-            tile_elementwise_in([](const DataType& a) { return a; }, receive_block_tile);
-        store_tile(receive_lds_window, receive_block_tile_tmp);
+        auto acc = cast_tile<ODataType>(host_block_tile);
         __syncthreads();
+        sweep_tile(receive_block_tile, [&](auto idx) {
+            acc(idx) = type_convert<DataType>(receive_block_tile(idx)) + acc(idx);
+        });
+        store_tile(const_cast<OutDramBlockWindowTmp&>(output_dram_block_window_tmp), acc);
     }
 };
...
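The register-only reduce boils down to three steps: cast the host tile to the accumulator type, sweep the received tile and add it in elementwise, then store the accumulator to the output window. A plain-C++ analogue of that accumulate step (stand-in std::array tiles and hypothetical sizes; the real code operates on distributed ck_tile tiles):

#include <array>
#include <cstdio>

int main()
{
    using DataType  = float; // stand-ins for the pipeline's template params
    using ODataType = float;

    std::array<DataType, 8> host_tile{1, 2, 3, 4, 5, 6, 7, 8};
    std::array<DataType, 8> recv_tile{8, 7, 6, 5, 4, 3, 2, 1};

    // acc = cast_tile<ODataType>(host_block_tile)
    std::array<ODataType, 8> acc{};
    for(std::size_t i = 0; i < acc.size(); ++i)
        acc[i] = static_cast<ODataType>(host_tile[i]);

    // sweep_tile: acc(idx) = type_convert<...>(receive_block_tile(idx)) + acc(idx)
    for(std::size_t i = 0; i < acc.size(); ++i)
        acc[i] = static_cast<ODataType>(recv_tile[i]) + acc[i];

    // store_tile(output_window, acc) would follow in the real pipeline
    std::printf("acc[0] = %f\n", acc[0]); // 9.0 for every element here
    return 0;
}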