Commit d4256471 authored by letaoqin

single block

parent 763e26be
......@@ -273,12 +273,12 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t M = 64;
ck::index_t N = 128;
ck::index_t K = DIM;
ck::index_t O = DIM;
ck::index_t G0 = 4;
ck::index_t G1 = 6;
ck::index_t G0 = 1;
ck::index_t G1 = 1;
bool input_permute = false;
bool output_permute = false;
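The G0/G1 group dimensions above are later folded into a single batch index g = g0 * G1 + g1 when the host reference tensors are built (see the ForEach lambdas further down). A minimal host-side sketch of that folding, assuming packed layouts and the commit's default sizes (the real tensors carry explicit strides; DIM is assumed to be 64 here):

```cpp
// Sketch of the gs -> g batch folding used by the verification code:
// a [G0, G1, M, N] tensor is viewed as [G0*G1, M, N] with g = g0 * G1 + g1,
// and v_gs_os_ns is additionally transposed from (o, n) to (n, o).
// Sizes are illustrative assumptions only; the real tensors use explicit strides.
#include <vector>

int main()
{
    const int G0 = 1, G1 = 1, M = 64, N = 128, O = 64; // O = DIM assumed to be 64
    std::vector<float> d_gs_ms_ns(G0 * G1 * M * N), d_g_m_n(G0 * G1 * M * N);
    std::vector<float> v_gs_os_ns(G0 * G1 * O * N), v_g_n_o(G0 * G1 * N * O);

    for(int g0 = 0; g0 < G0; ++g0)
        for(int g1 = 0; g1 < G1; ++g1)
        {
            const int g = g0 * G1 + g1; // folded batch index
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                    d_g_m_n[(g * M + m) * N + n] = d_gs_ms_ns[((g0 * G1 + g1) * M + m) * N + n];
            for(int o = 0; o < O; ++o)
                for(int n = 0; n < N; ++n)
                    v_g_n_o[(g * N + n) * O + o] = v_gs_os_ns[((g0 * G1 + g1) * O + o) * N + n];
        }
    return 0;
}
```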
......@@ -395,7 +395,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<ZDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
......@@ -404,6 +404,7 @@ int run(int argc, char* argv[])
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
......@@ -414,13 +415,12 @@ int run(int argc, char* argv[])
{
case 0: break;
case 1:
// q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
// d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{0});
//d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break;
case 2:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
......@@ -469,7 +469,7 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
......@@ -481,26 +481,6 @@ int run(int argc, char* argv[])
// = 0
}
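A quick check of the 0.0039 constant quoted in the comment above, under its assumption M = N = K = O = 256: with Q all ones and K diagonal, every score in a row of S = Q * K^T is equal, so the softmax of each row is uniformly 1/N = 1/256 ≈ 0.0039. A small sketch that reproduces the number:

```cpp
// Sketch verifying the 1/N = 0.0039 value from the comment above
// (assumes N = 256 as the comment does; softmax of a constant row is uniform).
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int N = 256;
    std::vector<double> s(N, 1.0); // every logit equal -> uniform softmax
    double sum = 0.0;
    for(double v : s)
        sum += std::exp(v);
    std::printf("softmax entry = %.6f (1/N = %.6f)\n", std::exp(s[0]) / sum, 1.0 / N);
    return 0;
}
```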
Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
Tensor<LSEDataType> lse_g_m({BatchCount, M});
q_gs_ms_ks.ForEach(
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
v_gs_os_ns.ForEach(
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
d_gs_ms_ns.ForEach(
[&](auto& self, auto idx) { d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
// q/k/v gradients have the same descriptors as q/k/v
DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
......@@ -517,7 +497,6 @@ int run(int argc, char* argv[])
q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
......@@ -630,15 +609,39 @@ int run(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
// copy z matrix data from device
z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
z_gs_ms_ns.ForEach(
[&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool pass = true;
if(do_verification)
{
// copy z matrix data from device
Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
Tensor<LSEDataType> lse_g_m({BatchCount, M});
z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
v_gs_os_ns.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
// run fwd again for y, because z_g_m_n was updated
run_attention_fwd_host(q_g_m_k,
k_g_n_k,
......
......@@ -118,7 +118,7 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
const D0DataType* tmp_p_d0_grid = nullptr;
const D0DataType* tmp_p_d0_grid = p_d0_grid;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
......@@ -850,10 +850,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_drop_{p_drop}
{
// TODO: implement bias addition
ignore = p_acc0_biases;
ignore = p_acc1_biases;
ignore = acc0_biases_gs_ms_ns_lengths;
ignore = acc0_biases_gs_ms_ns_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides;
......@@ -1042,7 +1039,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
has_main_k_block_loop_,
is_dropout_,
Deterministic>;
std::cout << "device address : " << arg.p_d0_grid_ << std::endl;
return launch_and_time_kernel(
stream_config,
kernel,
......
......@@ -123,9 +123,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto B1K1 = Number<B1K1Value>{};
// D0
static constexpr auto D0M3 = Number<2>{};
static constexpr auto D0M2 = Number<MPerXdl / D0M3.value>{};
static constexpr auto D0M1 = Number<MPerBlock / MPerXdl>{};
static constexpr auto D0M1 = Number<4>{};
static constexpr auto D0M0 = Number<MPerBlock / D0M1.value>{};
// static constexpr auto D0M1 = Number<MPerBlock / MPerXdl>{};
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
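This hunk fixes D0M1 at 4 and derives D0M0 = MPerBlock / D0M1, so the D0 tile's M extent is now factored as D0M0 * D0M1 instead of the previous MPerXdl-based split. A compile-time sketch of the factors, with MPerBlock = 128 and MPerXdl = 32 used purely as assumed example values:

```cpp
// Compile-time sketch of the new D0 tiling factors.
// MPerBlock = 128 and MPerXdl = 32 are assumed example values only.
#include <cstdio>

int main()
{
    constexpr int MPerBlock = 128, MPerXdl = 32;
    constexpr int D0M3 = 2;
    constexpr int D0M2 = MPerXdl / D0M3;   // 16 (unchanged by this commit)
    constexpr int D0M1 = 4;                // was MPerBlock / MPerXdl before this commit
    constexpr int D0M0 = MPerBlock / D0M1; // 32, so MPerBlock = D0M0 * D0M1
    static_assert(D0M0 * D0M1 == MPerBlock, "D0M0/D0M1 must tile MPerBlock exactly");
    std::printf("D0M0=%d D0M1=%d D0M2=%d D0M3=%d\n", D0M0, D0M1, D0M2, D0M3);
    return 0;
}
```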
......@@ -1160,18 +1160,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__ __device__ static constexpr auto
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(const D0GridDesc_M_N& d0_grid_desc_m_n)
{
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
// const auto M = d0_grid_desc_m_n.GetLength(I0);
// const auto N = d0_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
// const auto MBlock = M / MPerBlock;
// const auto NBlock = N / NPerBlock;
const auto d0_grid_desc_m0_n0_m1_m2_n1_m3 = transform_tensor_descriptor(
d0_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, D0M1, D0M2, D0M3)),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(make_unmerge_transform(make_tuple(D0M0, D0M1)),
make_unmerge_transform(make_tuple(Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 3, 5>{}, Sequence<1, 4>{}));
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return d0_grid_desc_m0_n0_m1_m2_n1_m3;
}
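In the reworked descriptor only the M dimension is unmerged: an element (m, n) of d0_grid_desc_m_n is addressed as (m0, n, m1) with m = m0 * D0M1 + m1, which lines up with the <D0M0, NPerBlock, D0M1> tile used by the blockwise copy and with the LDS block descriptor's (NPerBlock * D0M1, D0M1, 1) strides below. A plain-index sketch of that mapping, with D0M0 = 32, D0M1 = 4, NPerBlock = 128 as assumed values:

```cpp
// Sketch of the new D0 indexing: the grid view splits m into (m0, m1) with
// m = m0 * D0M1 + m1, and the LDS block descriptor stores the tile as
// (D0M0, NPerBlock, D0M1) with strides (NPerBlock * D0M1, D0M1, 1), i.e. m1 innermost.
// D0M0 = 32, D0M1 = 4, NPerBlock = 128 are assumed example values.
#include <cassert>

int main()
{
    constexpr int D0M0 = 32, D0M1 = 4, NPerBlock = 128;
    for(int m = 0; m < D0M0 * D0M1; ++m)
        for(int n = 0; n < NPerBlock; ++n)
        {
            const int m0 = m / D0M1, m1 = m % D0M1;      // unmerge of M
            const int lds_offset = m0 * NPerBlock * D0M1 // stride of m0
                                   + n * D0M1            // stride of n
                                   + m1;                 // m1 is contiguous
            assert(m0 * D0M1 + m1 == m);
            assert(lds_offset < D0M0 * NPerBlock * D0M1);
        }
    return 0;
}
```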
......@@ -1184,28 +1184,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3()
{
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(I1, I1, I1, D0M2, Number<NPerBlock>{}, D0M3),
make_tuple(Number<NPerBlock>{} * D0M3,
Number<NPerBlock>{} * D0M3,
Number<NPerBlock>{} * D0M3,
Number<NPerBlock>{} * D0M3,
D0M3,
I1));
return make_naive_tensor_descriptor(make_tuple(D0M0, Number<NPerBlock>{}, D0M1),
make_tuple(Number<NPerBlock>{} * D0M1, D0M1, I1));
}
__host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3()
{
constexpr auto d0_raw_m0_n_m1 =
make_naive_tensor_descriptor(make_tuple(D0M2, Number<NPerBlock>{}, D0M3),
make_tuple(Number<NPerBlock>{} * D0M3, D0M3, I1));
make_naive_tensor_descriptor(make_tuple(D0M0, Number<NPerBlock>{}, D0M1),
make_tuple(Number<NPerBlock>{} * D0M1, D0M1, I1));
constexpr auto d0_n0_n1_m0_m1_m2_m3 = transform_tensor_descriptor(
d0_raw_m0_n_m1,
make_tuple(make_unmerge_transform(make_tuple((D0M2 * D0M3) / I8, I2, I4 / D0M3)),
make_tuple(make_unmerge_transform(make_tuple(D0M0 / I2, I2)),
make_unmerge_transform(
make_tuple(Number<NPerBlock / NPerXdl>{}, Number<NPerXdl>{})),
make_pass_through_transform(D0M3)),
make_pass_through_transform(D0M1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<2, 3, 4>{}, Sequence<0, 1>{}, Sequence<5>{}));
make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
return d0_n0_n1_m0_m1_m2_m3;
}
static constexpr auto d0_block_desc_m0_n0_m1_m2_n1_m3 =
......@@ -1214,29 +1208,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3();
static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, I4 / D0M3, D0M3));
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I8, I1, D0M1));
using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, 1, 1, D0M2, NPerBlock, D0M3>, // BlockSliceLengths
Sequence<1, 1, 1, 8, 32, 1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
Sequence<D0M0, NPerBlock, D0M1>, // BlockSliceLengths
Sequence<8, 32, 1>, // ThreadClusterLengths
Sequence<0, 2, 1>, // ThreadClusterArrangeOrder
D0DataType, // SrcData
D0DataType, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim
Sequence<0, 2, 1>, // SrcDimAccessOrder
Sequence<1, 0, 2>, // DstDimAccessOrder
1, // SrcVectorDim
2, // DstVectorDim
NPerBlock / 32, // SrcScalarPerVector
D0M3.value / 1, // DstScalarPerVector
1, // SrcScalarPerVector
1, // DstScalarPerVector
1,
1,
false,
true,
true, // DstResetCoord
1>;
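As I read the ThreadGroupTensorSliceTransfer partitioning, the new copy splits the <D0M0, NPerBlock, D0M1> tile over an 8 x 32 x 1 thread cluster, so each thread moves a (D0M0/8) x (NPerBlock/32) x D0M1 sub-slice with scalar (non-vectorized) loads. A small arithmetic sketch under the assumed values D0M0 = 32, NPerBlock = 128, D0M1 = 4 and a 256-thread block:

```cpp
// Per-thread work for the new D0 blockwise copy: block slice / thread cluster.
// D0M0 = 32, NPerBlock = 128, D0M1 = 4 and 256 threads are assumed values.
#include <cstdio>

int main()
{
    constexpr int block_slice[3]    = {32, 128, 4}; // <D0M0, NPerBlock, D0M1>
    constexpr int thread_cluster[3] = {8, 32, 1};   // ThreadClusterLengths
    int per_thread = 1, threads = 1;
    for(int d = 0; d < 3; ++d)
    {
        per_thread *= block_slice[d] / thread_cluster[d];
        threads *= thread_cluster[d];
    }
    std::printf("threads=%d, elements per thread=%d, tile=%d\n",
                threads, per_thread, threads * per_thread); // 256, 64, 16384
    return 0;
}
```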
......@@ -1245,11 +1239,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
D0DataType, // DstData
decltype(d0_block_desc_n0_n1_m0_m1_m2_m3), // SrcDesc
decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, 4, 1, 2, 2>, // SliceLengths
Sequence<0, 1, 2, 3, 4, 5>, // DimAccessOrder
5, // SrcVectorDim
D0M3.value, // SrcScalarPerVector
D0M3.value>;
Sequence<1, 1, 8, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim
1, // SrcScalarPerVector
1>;
};
template <bool HasMainKBlockLoop,
......@@ -1739,13 +1733,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// D0
auto d0_block_copy_global_to_lds =
typename D0Loader::D0BlockwiseCopy(d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, block_work_idx_n, 0, 0, 0, 0),
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{},
D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0, 0));
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
ignore = d0_thread_copy_lds_to_vgpr;
//
// set up Y dot dY
......@@ -1922,6 +1916,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
s_slash_p_thread_buf,
num_k_block_main_loop);
// if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
// {
// auto p_lds = static_cast<D0DataType*>(p_shared);
// for(int idx = 0; idx < 32; idx++)
// {
// float tmp_gbl = ck::type_convert<float>(p_d0_grid[idx]);
// float tmp_lds = ck::type_convert<float>(p_lds[idx]);
// block_sync_lds();
// printf("p_d0_grid: %p, index: %d, gbl[idx]: %f, lds[idx]: %f \n",
// ck::type_convert<const void*>(p_d0_grid),
// idx,
// tmp_gbl,
// tmp_lds);
// }
// }
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
......@@ -1983,61 +1993,65 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// add bias
if constexpr(!is_same<D0DataType, void>::value)
{
static constexpr auto c_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<16>{}));
// static constexpr auto c_thread_desc_ =
// make_naive_tensor_descriptor_packed(make_tuple(I2, Number<16>{}));
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
static_cast<D0DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Loader::d0_thread_desc_.GetElementSpaceSize());
ignore = d0_thread_buf;
static_for<0, D0M1, 1>{}([&](auto mr) {
// load data to lds
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
d0_grid_buf);
// load data to lds
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
// d0_block_copy_global_to_lds.MoveSrcSliceWindow(
// d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite(D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
d0_block_buf);
block_sync_lds();
// read data from lds
d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_desc_n0_n1_m0_m1_m2_m3,
make_tuple(I0, I0, I0, I0, I0, I0),
d0_block_buf,
D0Loader::d0_thread_desc_,
make_tuple(I0, I0, I0, I0, I0, I0),
d0_thread_buf);
// bias add
static_for<0, D0Loader::d0_thread_desc_.GetElementSpaceSize(), 1>{}(
[&](auto i) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(mr, i));
s_slash_p_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
// if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
// {
// printf("c_offset: %d s_slash_p_thread_buf(Number<c_offset>{}):
// %f, "
// "d0_thread_buf[i]: %f\n",
// c_offset,
// s_slash_p_thread_buf(Number<c_offset>{}),
// ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]));
// }
});
d0_block_copy_global_to_lds.RunWrite(D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
d0_block_buf);
block_sync_lds();
// read data from lds
d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_desc_n0_n1_m0_m1_m2_m3,
make_tuple(I0, I0, I0, I0, I0),
d0_block_buf,
D0Loader::d0_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
d0_thread_buf);
// static_for<0, 2, 1>{}([&](auto mr) {
// bias add
static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(mr, i));
//if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
// if(ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]) != 1.0f)
// {
// float tmp_lds =
// ck::type_convert<float>(static_cast<D0DataType*>(p_shared)[i.value]);
// float tmp_gbl = ck::type_convert<float>(p_d0_grid[i.value]);
// block_sync_lds();
// printf("id: %d : i: %d, gbl[i]: %f "
// ",lds[i]: %f d0_thread_buf[i]: %f\n",
// get_thread_local_1d_id(),
// i.value,
// tmp_gbl,
// tmp_lds,
// ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]));
// }
s_slash_p_thread_buf(i) += ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
});
//});
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(1, 0, -D0M1.value, 0, 0, 0));
// d0_block_copy_global_to_lds.MoveSrcSliceWindow(
// d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(1, 0, -D0M1.value, 0, 0,
// 0));
}
// P_i: = softmax(scalar * S_i:)
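Functionally, the bias path above stages D0 through global -> LDS -> VGPR and then accumulates it elementwise into s_slash_p_thread_buf, i.e. the attention scores get the D0 bias added before the softmax. A host-side sketch of that intended math (tile sizes are made up, and the exact point where the alpha scale is applied is not visible in this hunk, so it is omitted):

```cpp
// Host-side sketch of the D0 path: the bias is added elementwise to the
// attention scores, then the softmax is taken over each row of the biased scores.
// M and N are assumed example tile sizes only.
#include <cmath>
#include <vector>

int main()
{
    const int M = 4, N = 8; // assumed tile sizes
    std::vector<float> s(M * N, 1.0f), d0(M * N, 0.5f), p(M * N);
    for(int i = 0; i < M * N; ++i)
        s[i] += d0[i]; // mirrors s_slash_p_thread_buf(i) += d0_thread_buf[i]
    for(int m = 0; m < M; ++m) // row-wise softmax of the biased scores
    {
        float sum = 0.f;
        for(int n = 0; n < N; ++n)
            sum += std::exp(s[m * N + n]);
        for(int n = 0; n < N; ++n)
            p[m * N + n] = std::exp(s[m * N + n]) / sum;
    }
    return 0;
}
```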
......