"tests/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "1c6645825c99ff474be2e199018f48ade9b2a763"
Commit c8b4ac22 authored by rocking
Browse files

Add global write

parent a760a732
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "cluster_descriptor.hpp" #include "cluster_descriptor.hpp"
#include "data_type.hpp" #include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
namespace ck { namespace ck {
...@@ -48,6 +49,7 @@ struct GridwiseElementwise_2D ...@@ -48,6 +49,7 @@ struct GridwiseElementwise_2D
static constexpr auto thread_buf_desc_M_N = make_naive_tensor_descriptor_packed( static constexpr auto thread_buf_desc_M_N = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadTileSize>{}, Number<NThreadTileSize>{})); make_tuple(Number<MThreadTileSize>{}, Number<NThreadTileSize>{}));
using PassThrough = tensor_operation::element_wise::PassThrough;
using ThreadBufDesc_M_N = decltype(thread_buf_desc_M_N); using ThreadBufDesc_M_N = decltype(thread_buf_desc_M_N);
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -88,7 +90,7 @@ struct GridwiseElementwise_2D ...@@ -88,7 +90,7 @@ struct GridwiseElementwise_2D
p_a_global, a_grid_desc_m_n.GetElementSpaceSize()); p_a_global, a_grid_desc_m_n.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_grid_desc_m_n.GetElementSpaceSize()); p_b_global, b_grid_desc_m_n.GetElementSpaceSize());
const auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_grid_desc_m_n.GetElementSpaceSize()); p_c_global, c_grid_desc_m_n.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, MThreadTileSize * NThreadTileSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, MThreadTileSize * NThreadTileSize, true>
...@@ -141,10 +143,23 @@ struct GridwiseElementwise_2D ...@@ -141,10 +143,23 @@ struct GridwiseElementwise_2D
}); });
// TODO - global write // TODO - global write
(void)c_global_buf; const auto c_global_write_offset = CalculateElementwiseIndex(c_grid_desc_m_n);
// c_global_write.Run( auto c_global_write = ThreadwiseTensorSliceTransfer_v1r3<
// thread_buf_desc_M_N, c_thread_buf, c_grid_desc_m_n, make_tuple(I0, I0), CDataType,
// c_global_buf); CDataType,
decltype(thread_buf_desc_M_N),
GridDesc_M_N,
PassThrough,
Sequence<MThreadTileSize, NThreadTileSize>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
1, // DstVectorDim
CThreadTransferSrcScalarPerVector, // DstScalarPerVector
InMemoryDataOperationEnum::Set, // DstInMemOp
1, // DstScalarStrideInVector
false>{c_grid_desc_m_n, c_global_write_offset, PassThrough{}};
c_global_write.Run(
thread_buf_desc_M_N, make_tuple(I0, I0), c_thread_buf, c_grid_desc_m_n, c_global_buf);
} }
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment