Commit d8154515 authored by ltqin's avatar ltqin
Browse files

code regular

parent a085b740
...@@ -51,7 +51,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl ...@@ -51,7 +51,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 32, 32, 4, 8, 16, 16, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>; < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
// clang-format on // clang-format on
using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
...@@ -92,8 +92,8 @@ int main(int argc, char* argv[]) ...@@ -92,8 +92,8 @@ int main(int argc, char* argv[])
bool time_kernel = false; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 32; ck::index_t M = 16;
ck::index_t N = 32; ck::index_t N = 16;
ck::index_t K = 64; ck::index_t K = 64;
ck::index_t StrideA = K; ck::index_t StrideA = K;
......
...@@ -36,7 +36,6 @@ struct BlockwiseSoftmax_V1 ...@@ -36,7 +36,6 @@ struct BlockwiseSoftmax_V1
template <typename TopIdx> template <typename TopIdx>
__host__ __device__ static constexpr auto CalculateBottomIndex(const TopIdx& idx_top) __host__ __device__ static constexpr auto CalculateBottomIndex(const TopIdx& idx_top)
{ {
const auto index = idx_top[I0]; const auto index = idx_top[I0];
const auto m = (index / WaveSize) * MPerXDL + index % MPerXDL; const auto m = (index / WaveSize) * MPerXDL + index % MPerXDL;
const auto k = (index % WaveSize) / MPerXDL; const auto k = (index % WaveSize) / MPerXDL;
...@@ -101,12 +100,12 @@ struct BlockwiseSoftmax_V1 ...@@ -101,12 +100,12 @@ struct BlockwiseSoftmax_V1
auto softmax_lds_buffer = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto softmax_lds_buffer = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<AccDataType*>(p_softmax), MPerBlock * 2); static_cast<AccDataType*>(p_softmax), MPerBlock * 2);
// static auto lds_buffer_m_k = GetSpaceForPreMax(); // thread id map to thread layout
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const auto thread_cluster_idx = const auto thread_cluster_idx =
BlockToMKMap_M0_K_M1Adapt::CalculateBottomIndex(make_multi_index(thread_local_id)); BlockToMKMap_M0_K_M1Adapt::CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
// const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}]; const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
// //
// find max value // find max value
// //
...@@ -123,17 +122,15 @@ struct BlockwiseSoftmax_V1 ...@@ -123,17 +122,15 @@ struct BlockwiseSoftmax_V1
ThreadwiseMaxReduce::Reduce(xdlops_out.template AsType<float>(), max_value_buf); ThreadwiseMaxReduce::Reduce(xdlops_out.template AsType<float>(), max_value_buf);
}); });
//{const index_t thread_local_id = get_thread_local_1d_id(); // block reduce for max
// printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
// ignore = p_reduce_work_buffer;}
BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0)); BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0));
block_sync_lds(); block_sync_lds();
// save max value // save max value
softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset( if(0 == thread_k_cluster_id)
make_tuple(thread_m_cluster_id, 1))) = max_value_buf(I0); {
softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset(
printf("thread id: %d, Max: %f\t\t", thread_local_id, max_value_buf[I0]); make_tuple(thread_m_cluster_id, 1))) = max_value_buf(I0);
}
// //
// softmax // softmax
...@@ -143,7 +140,7 @@ struct BlockwiseSoftmax_V1 ...@@ -143,7 +140,7 @@ struct BlockwiseSoftmax_V1
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>(); accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
}); });
// calculate exp for elements // calculate exp for elements, P=exp(s-max)
static_for<0, NRepeat, 1>{}([&](auto n) { static_for<0, NRepeat, 1>{}([&](auto n) {
constexpr index_t in_offset = in_thread_desc.CalculateOffset(make_tuple(0, n, 0)); constexpr index_t in_offset = in_thread_desc.CalculateOffset(make_tuple(0, n, 0));
auto& xdlops_out = in_thread_buf.GetVectorTypeReference(Number<in_offset>{}); auto& xdlops_out = in_thread_buf.GetVectorTypeReference(Number<in_offset>{});
...@@ -164,19 +161,11 @@ struct BlockwiseSoftmax_V1 ...@@ -164,19 +161,11 @@ struct BlockwiseSoftmax_V1
block_sync_lds(); block_sync_lds();
// save sum // save sum
softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset( if(0 == thread_k_cluster_id)
make_tuple(thread_m_cluster_id, 0))) = accu_value_buf(I0); {
// change elements softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset(
/* static_for<0, NRepeat, 1>{}([&](auto n) { make_tuple(thread_m_cluster_id, 0))) = accu_value_buf(I0);
constexpr index_t in_offset = in_thread_desc.CalculateOffset(make_tuple(0, n, 0)); }
auto& xdlops_out =
in_thread_buf.GetVectorTypeReference(Number<in_offset>{});
static_for<0, RegSizePerXdlops, 1>{}([&](auto iK) {
xdlops_out.template AsType<float>()(iK) =
xdlops_out.template AsType<float>()[iK] / accu_value_buf(I0);
});
});*/
} }
}; // namespace ck }; // namespace ck
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment