Commit c0252636 authored by ltqin's avatar ltqin
Browse files

regular code

parent 787626fb
......@@ -76,17 +76,17 @@ struct BlockwiseSoftmax_V1
{
// printf("c_thread_desc: {%d, %d, %d}", c_thread_desc.GetLength(I0).value,
// c_thread_desc.GetLength(I1).value, c_thread_desc.GetLength(I2).value);
auto p_reduce_work_buffer = static_cast<AccDataType*>(p_shared);
auto reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<AccDataType*>(p_shared), BlockSize);
//
// find max value
//
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> max_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
});
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
});
// max value for one thread
static_for<0, NRepeat, 1>{}([&](auto n) {
constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
......@@ -99,47 +99,50 @@ struct BlockwiseSoftmax_V1
// printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
// ignore = p_reduce_work_buffer;}
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0));
block_sync_lds();
// {const index_t thread_local_id = get_thread_local_1d_id();
// printf("thread id: %d, Max: %f\t\t", thread_local_id, max_value_buf[I0]);}
//
// softmax
{
// calculate exp for elements
static_for<0, NRepeat, 1>{}([&](auto n) {
constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
static_for<0, RegSizePerXdlops, 1>{}([&](auto iK) {
xdlops_out.template AsType<float>()(iK) =
math::exp(xdlops_out.template AsType<float>()[iK] - max_value_buf(I0));
});
});
// sum data
static_for<0, NRepeat, 1>{}([&](auto n) {
constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
ThreadwiseSumReduce::Reduce(xdlops_out.template AsType<float>(), accu_value_buf);
block_sync_lds();
//
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
});
// calculate exp for elements
static_for<0, NRepeat, 1>{}([&](auto n) {
constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
static_for<0, RegSizePerXdlops, 1>{}([&](auto iK) {
xdlops_out.template AsType<float>()(iK) =
math::exp(xdlops_out.template AsType<float>()[iK] - max_value_buf(I0));
});
BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0));
});
// sum data
static_for<0, NRepeat, 1>{}([&](auto n) {
constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
ThreadwiseSumReduce::Reduce(xdlops_out.template AsType<float>(), accu_value_buf);
block_sync_lds();
});
BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0));
block_sync_lds();
// change elements
static_for<0, NRepeat, 1>{}([&](auto n) {
constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
// change elements
static_for<0, NRepeat, 1>{}([&](auto n) {
constexpr index_t c_offset = c_thread_desc.CalculateOffset(make_tuple(0, n, 0));
auto& xdlops_out = c_thread_buf.GetVectorTypeReference(Number<c_offset>{});
static_for<0, c_thread_desc.GetLength(I2), 1>{}([&](auto iK) {
xdlops_out.template AsType<float>()(iK) =
xdlops_out.template AsType<float>()[iK] / accu_value_buf(I0);
});
static_for<0, c_thread_desc.GetLength(I2), 1>{}([&](auto iK) {
xdlops_out.template AsType<float>()(iK) =
xdlops_out.template AsType<float>()[iK] / accu_value_buf(I0);
});
}
});
}
}; // namespace ck
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment