Commit c0252636 authored by ltqin
Browse files

regular code

parent 787626fb
...@@ -76,17 +76,17 @@ struct BlockwiseSoftmax_V1 ...@@ -76,17 +76,17 @@ struct BlockwiseSoftmax_V1
{ {
// printf("c_thread_desc: {%d, %d, %d}", c_thread_desc.GetLength(I0).value, // printf("c_thread_desc: {%d, %d, %d}", c_thread_desc.GetLength(I0).value,
// c_thread_desc.GetLength(I1).value, c_thread_desc.GetLength(I2).value); // c_thread_desc.GetLength(I1).value, c_thread_desc.GetLength(I2).value);
auto p_reduce_work_buffer = static_cast<AccDataType*>(p_shared); auto reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<AccDataType*>(p_shared), BlockSize);
//
// find max value
//
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> max_value_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> max_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>(); max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
}); });
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
});
// max value for one thread // max value for one thread
static_for<0, NRepeat, 1>{}([&](auto n) { static_for<0, NRepeat, 1>{}([&](auto n) {
constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0)); constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
...@@ -99,47 +99,50 @@ struct BlockwiseSoftmax_V1 ...@@ -99,47 +99,50 @@ struct BlockwiseSoftmax_V1
// printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]); // printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
// ignore = p_reduce_work_buffer;} // ignore = p_reduce_work_buffer;}
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0)); BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0));
block_sync_lds(); block_sync_lds();
// {const index_t thread_local_id = get_thread_local_1d_id(); // {const index_t thread_local_id = get_thread_local_1d_id();
// printf("thread id: %d, Max: %f\t\t", thread_local_id, max_value_buf[I0]);} // printf("thread id: %d, Max: %f\t\t", thread_local_id, max_value_buf[I0]);}
//
// softmax // softmax
{ //
// calculate exp for elements
static_for<0, NRepeat, 1>{}([&](auto n) { StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0)); static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{}); accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
});
static_for<0, RegSizePerXdlops, 1>{}([&](auto iK) { // calculate exp for elements
xdlops_out.template AsType<float>()(iK) = static_for<0, NRepeat, 1>{}([&](auto n) {
math::exp(xdlops_out.template AsType<float>()[iK] - max_value_buf(I0)); constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
}); auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
});
// sum data static_for<0, RegSizePerXdlops, 1>{}([&](auto iK) {
static_for<0, NRepeat, 1>{}([&](auto n) { xdlops_out.template AsType<float>()(iK) =
constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0)); math::exp(xdlops_out.template AsType<float>()[iK] - max_value_buf(I0));
auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
ThreadwiseSumReduce::Reduce(xdlops_out.template AsType<float>(), accu_value_buf);
block_sync_lds();
}); });
BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0)); });
// sum data
static_for<0, NRepeat, 1>{}([&](auto n) {
constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
ThreadwiseSumReduce::Reduce(xdlops_out.template AsType<float>(), accu_value_buf);
block_sync_lds(); block_sync_lds();
});
BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0));
block_sync_lds();
// change elements // change elements
static_for<0, NRepeat, 1>{}([&](auto n) { static_for<0, NRepeat, 1>{}([&](auto n) {
constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0)); constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{}); auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
static_for<0, c_thread_desc.GetLength(I2), 1>{}([&](auto iK) { static_for<0, c_thread_desc.GetLength(I2), 1>{}([&](auto iK) {
xdlops_out.template AsType<float>()(iK) = xdlops_out.template AsType<float>()(iK) =
xdlops_out.template AsType<float>()[iK] / accu_value_buf(I0); xdlops_out.template AsType<float>()[iK] / accu_value_buf(I0);
});
}); });
} });
} }
}; // namespace ck }; // namespace ck
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment