"tests/test_models/vscode:/vscode.git/clone" did not exist on "e013bab5674e8d35d1998a050e1fa239ac9a747d"
Commit d4256471 authored by letaoqin's avatar letaoqin
Browse files

single block

parent 763e26be
@@ -273,12 +273,12 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t M  = 512;
-    ck::index_t N  = 512;
+    ck::index_t M  = 64;
+    ck::index_t N  = 128;
     ck::index_t K  = DIM;
     ck::index_t O  = DIM;
-    ck::index_t G0 = 4;
-    ck::index_t G1 = 6;
+    ck::index_t G0 = 1;
+    ck::index_t G1 = 1;
     bool input_permute  = false;
     bool output_permute = false;
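For orientation, the comments above state what the example computes for each batch group g in [0, G0 * G1). Below is a minimal host-side sketch of that formula, assuming row-major float buffers and omitting dropout and the LSE output; the free function and its signature are illustrative stand-ins, not the example's run_attention_fwd_host.

#include <cmath>
#include <limits>
#include <vector>

// y = Softmax(alpha * Q * K^T) * V for one batch group.
// Q is M x Kdim, K is N x Kdim, V is N x O, Y is M x O (row-major);
// alpha is typically 1 / sqrt(Kdim).
void attention_fwd_ref(const std::vector<float>& Q,
                       const std::vector<float>& K,
                       const std::vector<float>& V,
                       std::vector<float>& Y,
                       int M, int N, int Kdim, int O, float alpha)
{
    std::vector<float> s(N);
    for(int m = 0; m < M; ++m)
    {
        // S_m[n] = alpha * dot(Q_m, K_n), tracking the row max for a stable softmax
        float s_max = -std::numeric_limits<float>::infinity();
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < Kdim; ++k)
                acc += Q[m * Kdim + k] * K[n * Kdim + k];
            s[n]  = alpha * acc;
            s_max = s[n] > s_max ? s[n] : s_max;
        }
        float sum = 0.f;
        for(int n = 0; n < N; ++n)
        {
            s[n] = std::exp(s[n] - s_max);
            sum += s[n];
        }
        // Y_m = P_m * V, where P_m[n] = s[n] / sum
        for(int o = 0; o < O; ++o)
        {
            float y = 0.f;
            for(int n = 0; n < N; ++n)
                y += (s[n] / sum) * V[n * O + o];
            Y[m * O + o] = y;
        }
    }
}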
@@ -395,7 +395,7 @@ int run(int argc, char* argv[])
     Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
     Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-    Tensor<ZDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
+    Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
     Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
     Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
     Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
@@ -404,6 +404,7 @@ int run(int argc, char* argv[])
     std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
     std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
+    std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
     std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
     std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
     std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
@@ -414,13 +415,12 @@ int run(int argc, char* argv[])
     {
     case 0: break;
     case 1:
-        // q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
         d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
-        // d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{0});
+        // d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         break;
     case 2:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
@@ -469,7 +469,7 @@ int run(int argc, char* argv[])
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
         d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         // assume mnko = 256
         // P = softmax(QK) = 0.0039 * ones
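(The 0.0039 above is exactly the uniform softmax value: with Q all ones and K generated by GeneratorTensor_Diagonal, every row of S = Q * K^T is constant, and a softmax over N = 256 equal logits gives 1/256 ≈ 0.0039 per entry.)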
@@ -481,26 +481,6 @@ int run(int argc, char* argv[])
         // = 0
     }

-    Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
-    Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
-    Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
-    Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
-    Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
-    Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
-    Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
-    Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
-    Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
-    Tensor<LSEDataType> lse_g_m({BatchCount, M});
-
-    q_gs_ms_ks.ForEach(
-        [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
-    k_gs_ns_ks.ForEach(
-        [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
-    v_gs_os_ns.ForEach(
-        [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
-    d_gs_ms_ns.ForEach(
-        [&](auto& self, auto idx) { d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
-
     // qkv gradients have the same descriptor as with qkv
     DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
     DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
@@ -517,7 +497,6 @@ int run(int argc, char* argv[])
     q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
     k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
     d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
-    z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
     v_device_buf.ToDevice(v_gs_os_ns.mData.data());
     ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
@@ -630,15 +609,39 @@ int run(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
               << gemm.GetTypeString() << std::endl;

-    // copy z matrix data from device
-    z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
-    z_gs_ms_ns.ForEach(
-        [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
     // std::cout << "z_g_m_n ref:\n" << z_g_m_n;

     bool pass = true;
     if(do_verification)
     {
+        // copy z matrix data from device
+        Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
+        Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
+        Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
+        Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
+        Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
+        Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
+        Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
+        Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
+        Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
+        Tensor<LSEDataType> lse_g_m({BatchCount, M});
+
+        z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
+        z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
+            z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        });
+        q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
+            q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        });
+        k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
+            k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        });
+        v_gs_os_ns.ForEach([&](auto& self, auto idx) {
+            v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        });
+        d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
+            d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        });
+
         // run fwd again for y, cause z_g_m_n update
         run_attention_fwd_host(q_g_m_k,
                                k_g_n_k,
...
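The ForEach copies above flatten the two batch dimensions for the host reference: group (g0, g1) maps to the single batch index g = g0 * G1 + g1, and for V the last two indices are additionally swapped (idx[3], idx[2]), turning the [G0, G1, O, N] layout into the [G, N, O] layout the reference math expects. A tiny stand-alone illustration of the batch flattening (hypothetical, prints the mapping only):

#include <cstdio>

int main()
{
    // Same index math as the ForEach copies: (g0, g1) -> g = g0 * G1 + g1.
    const int G0 = 2, G1 = 3;
    for(int g0 = 0; g0 < G0; ++g0)
        for(int g1 = 0; g1 < G1; ++g1)
            std::printf("(g0=%d, g1=%d) -> g=%d\n", g0, g1, g0 * G1 + g1);
    return 0;
}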
@@ -118,7 +118,7 @@ __global__ void
     const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;

-    const D0DataType* tmp_p_d0_grid = nullptr;
+    const D0DataType* tmp_p_d0_grid = p_d0_grid;
     if constexpr(!is_same<D0DataType, void>::value)
     {
         const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
@@ -850,10 +850,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
           p_drop_{p_drop}
     {
         // TODO: implement bias addition
-        ignore = p_acc0_biases;
         ignore = p_acc1_biases;
-        ignore = acc0_biases_gs_ms_ns_lengths;
-        ignore = acc0_biases_gs_ms_ns_strides;
         ignore = acc1_biases_gs_ms_gemm1ns_lengths;
         ignore = acc1_biases_gs_ms_gemm1ns_strides;
@@ -1042,7 +1039,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             has_main_k_block_loop_,
             is_dropout_,
             Deterministic>;

-        std::cout << "device address : " << arg.p_d0_grid_ << std::endl;
         return launch_and_time_kernel(
             stream_config,
             kernel,
...
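The first hunk above seeds the optional bias pointer with p_d0_grid instead of nullptr before the if constexpr block. A minimal sketch of that pattern, with a hypothetical function name, showing why the unconditional initialization is safe even when D0DataType is void:

#include <type_traits>

// The batch-offset arithmetic only exists when the bias type is non-void;
// with D0DataType = void, that branch is discarded at compile time, so no
// void-pointer arithmetic is ever instantiated.
template <typename D0DataType>
void consume_bias(const D0DataType* p_d0_grid, long d0_batch_offset)
{
    const D0DataType* tmp_p_d0_grid = p_d0_grid; // was nullptr before this commit
    if constexpr(!std::is_same<D0DataType, void>::value)
    {
        tmp_p_d0_grid = p_d0_grid + d0_batch_offset; // offset only when a bias exists
    }
    (void)tmp_p_d0_grid;
}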
@@ -123,9 +123,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     static constexpr auto B1K1 = Number<B1K1Value>{};

     // D0
-    static constexpr auto D0M3 = Number<2>{};
-    static constexpr auto D0M2 = Number<MPerXdl / D0M3.value>{};
-    static constexpr auto D0M1 = Number<MPerBlock / MPerXdl>{};
+    static constexpr auto D0M1 = Number<4>{};
+    static constexpr auto D0M0 = Number<MPerBlock / D0M1.value>{};
+    // static constexpr auto D0M1 = Number<MPerBlock / MPerXdl>{};

     static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
     static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
@@ -1160,18 +1160,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     __host__ __device__ static constexpr auto
     MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(const D0GridDesc_M_N& d0_grid_desc_m_n)
     {
-        const auto M = d0_grid_desc_m_n.GetLength(I0);
-        const auto N = d0_grid_desc_m_n.GetLength(I1);
-        const auto MBlock = M / MPerBlock;
-        const auto NBlock = N / NPerBlock;
+        // const auto M = d0_grid_desc_m_n.GetLength(I0);
+        // const auto N = d0_grid_desc_m_n.GetLength(I1);
+        // const auto MBlock = M / MPerBlock;
+        // const auto NBlock = N / NPerBlock;

         const auto d0_grid_desc_m0_n0_m1_m2_n1_m3 = transform_tensor_descriptor(
             d0_grid_desc_m_n,
-            make_tuple(make_unmerge_transform(make_tuple(MBlock, D0M1, D0M2, D0M3)),
-                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
+            make_tuple(make_unmerge_transform(make_tuple(D0M0, D0M1)),
+                       make_unmerge_transform(make_tuple(Number<NPerBlock>{}))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2, 3, 5>{}, Sequence<1, 4>{}));
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

         return d0_grid_desc_m0_n0_m1_m2_n1_m3;
     }
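For reference, make_unmerge_transform splits one linear dimension into nested coordinates. After this change the M axis of the D0 grid is viewed as m = m0 * D0M1 + m1 with D0M1 = 4 (so each block covers D0M0 = MPerBlock / 4 groups of four consecutive rows), replacing the previous four-way MBlock / D0M1 / D0M2 / D0M3 split of M.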
@@ -1184,28 +1184,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     __host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3()
     {
         // B1 matrix in LDS memory, dst of blockwise copy
-        return make_naive_tensor_descriptor(
-            make_tuple(I1, I1, I1, D0M2, Number<NPerBlock>{}, D0M3),
-            make_tuple(Number<NPerBlock>{} * D0M3,
-                       Number<NPerBlock>{} * D0M3,
-                       Number<NPerBlock>{} * D0M3,
-                       Number<NPerBlock>{} * D0M3,
-                       D0M3,
-                       I1));
+        return make_naive_tensor_descriptor(make_tuple(D0M0, Number<NPerBlock>{}, D0M1),
+                                            make_tuple(Number<NPerBlock>{} * D0M1, D0M1, I1));
     }

     __host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3()
     {
         constexpr auto d0_raw_m0_n_m1 =
-            make_naive_tensor_descriptor(make_tuple(D0M2, Number<NPerBlock>{}, D0M3),
-                                         make_tuple(Number<NPerBlock>{} * D0M3, D0M3, I1));
+            make_naive_tensor_descriptor(make_tuple(D0M0, Number<NPerBlock>{}, D0M1),
+                                         make_tuple(Number<NPerBlock>{} * D0M1, D0M1, I1));

         constexpr auto d0_n0_n1_m0_m1_m2_m3 = transform_tensor_descriptor(
             d0_raw_m0_n_m1,
-            make_tuple(make_unmerge_transform(make_tuple((D0M2 * D0M3) / I8, I2, I4 / D0M3)),
+            make_tuple(make_unmerge_transform(make_tuple(D0M0 / I2, I2)),
                        make_unmerge_transform(
                            make_tuple(Number<NPerBlock / NPerXdl>{}, Number<NPerXdl>{})),
-                       make_pass_through_transform(D0M3)),
+                       make_pass_through_transform(D0M1)),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<2, 3, 4>{}, Sequence<0, 1>{}, Sequence<5>{}));
+            make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));

         return d0_n0_n1_m0_m1_m2_m3;
     }

     static constexpr auto d0_block_desc_m0_n0_m1_m2_n1_m3 =
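The replacement LDS descriptor is a plain strided view: lengths (D0M0, NPerBlock, D0M1) with strides (NPerBlock * D0M1, D0M1, 1) place element (m0, n, m1) at offset m0 * NPerBlock * D0M1 + n * D0M1 + m1, so the D0M1 = 4 M-values of one group sit contiguously along the fastest LDS dimension.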
@@ -1214,29 +1208,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3();

     static constexpr auto d0_thread_desc_ =
-        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, I4 / D0M3, D0M3));
+        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I8, I1, D0M1));

     using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
         ThisThreadBlock,
         tensor_operation::element_wise::PassThrough,
         tensor_operation::element_wise::PassThrough,
         InMemoryDataOperationEnum::Set,
-        Sequence<1, 1, 1, D0M2, NPerBlock, D0M3>,  // BlockSliceLengths
-        Sequence<1, 1, 1, 8, 32, 1>,               // ThreadClusterLengths
-        Sequence<0, 1, 2, 3, 5, 4>,                // ThreadClusterArrangeOrder
+        Sequence<D0M0, NPerBlock, D0M1>,           // BlockSliceLengths
+        Sequence<8, 32, 1>,                        // ThreadClusterLengths
+        Sequence<0, 2, 1>,                         // ThreadClusterArrangeOrder
         D0DataType,                                // SrcData
         D0DataType,                                // DstData
         D0GridDescriptor_M0_N0_M1_M2_N1_M3,        // SrcDesc
         decltype(d0_block_desc_m0_n0_m1_m2_n1_m3), // DstDesc
-        Sequence<0, 1, 2, 3, 5, 4>,                // SrcDimAccessOrder
-        Sequence<0, 1, 2, 4, 3, 5>,                // DstDimAccessOrder
-        4,                                         // SrcVectorDim
+        Sequence<0, 2, 1>,                         // SrcDimAccessOrder
+        Sequence<1, 0, 2>,                         // DstDimAccessOrder
+        1,                                         // SrcVectorDim
         2,                                         // DstVectorDim
-        NPerBlock / 32,                            // SrcScalarPerVector
-        D0M3.value / 1,                            // DstScalarPerVector
+        1,                                         // SrcScalarPerVector
+        1,                                         // DstScalarPerVector
         1,
         1,
-        false,
+        true,
         true, // DstResetCoord
         1>;
@@ -1245,11 +1239,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         D0DataType,                                // DstData
         decltype(d0_block_desc_n0_n1_m0_m1_m2_m3), // SrcDesc
         decltype(d0_thread_desc_),                 // DstDesc
-        Sequence<1, 1, 4, 1, 2, 2>,                // SliceLengths
-        Sequence<0, 1, 2, 3, 4, 5>,                // DimAccessOrder
-        5,                                         // SrcVectorDim
-        D0M3.value,                                // SrcScalarPerVector
-        D0M3.value>;
+        Sequence<1, 1, 8, 1, 4>,                   // SliceLengths
+        Sequence<0, 1, 2, 3, 4>,                   // DimAccessOrder
+        4,                                         // SrcVectorDim
+        1,                                         // SrcScalarPerVector
+        1>;
 };
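A quick consistency check on the rewritten copy parameters: the thread cluster Sequence<8, 32, 1> is 8 * 32 = 256 threads (one workgroup) tiling the block slice Sequence<D0M0, NPerBlock, D0M1>, so each thread moves (D0M0 / 8) * (NPerBlock / 32) * D0M1 elements per pass; with SrcScalarPerVector and DstScalarPerVector both set to 1, every access is now scalar, replacing the previous NPerBlock / 32-wide vector reads.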
template <bool HasMainKBlockLoop,
@@ -1739,13 +1733,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         // D0
         auto d0_block_copy_global_to_lds =
             typename D0Loader::D0BlockwiseCopy(d0_grid_desc_m0_n0_m1_m2_n1_m3,
-                                               make_multi_index(0, block_work_idx_n, 0, 0, 0, 0),
+                                               make_multi_index(0, 0, 0),
                                                tensor_operation::element_wise::PassThrough{},
                                                D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
-                                               make_multi_index(0, 0, 0, 0, 0, 0),
+                                               make_multi_index(0, 0, 0),
                                                tensor_operation::element_wise::PassThrough{});
         auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
-            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0, 0));
+            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
         ignore = d0_thread_copy_lds_to_vgpr;
         //
         // set up Y dot dY
@@ -1922,6 +1916,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                 s_slash_p_thread_buf,
                 num_k_block_main_loop);

+            // if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
+            // {
+            //     auto p_lds = static_cast<D0DataType*>(p_shared);
+            //     for(int idx = 0; idx < 32; idx++)
+            //     {
+            //         float tmp_gbl = ck::type_convert<float>(p_d0_grid[idx]);
+            //         float tmp_lds = ck::type_convert<float>(p_lds[idx]);
+            //         block_sync_lds();
+            //         printf("p_d0_grid: %p, index: %d, gbl[idx]: %f, lds[idx]: %f \n",
+            //                ck::type_convert<const void*>(p_d0_grid),
+            //                idx,
+            //                tmp_gbl,
+            //                tmp_lds);
+            //     }
+            // }
+
             // 8d thread_desc in thread scope
             constexpr auto c_thread_lengths =
                 s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
@@ -1983,61 +1993,65 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             // add bias
             if constexpr(!is_same<D0DataType, void>::value)
             {
-                static constexpr auto c_thread_desc_ =
-                    make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<16>{}));
+                // static constexpr auto c_thread_desc_ =
+                //     make_naive_tensor_descriptor_packed(make_tuple(I2, Number<16>{}));

                 const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                     p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());

                 auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-                    static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
+                    static_cast<D0DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
                     D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());

                 auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
                     D0Loader::d0_thread_desc_.GetElementSpaceSize());
                 ignore = d0_thread_buf;

-                static_for<0, D0M1, 1>{}([&](auto mr) {
                 // load data to lds
-                    d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
-                                                        d0_grid_buf);
-                    d0_block_copy_global_to_lds.MoveSrcSliceWindow(
-                        d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
+                d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_buf);
+                // d0_block_copy_global_to_lds.MoveSrcSliceWindow(
+                //     d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
                 d0_block_copy_global_to_lds.RunWrite(D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
                                                      d0_block_buf);
                 block_sync_lds();
                 // read data from lds
                 d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_desc_n0_n1_m0_m1_m2_m3,
-                                               make_tuple(I0, I0, I0, I0, I0, I0),
+                                               make_tuple(I0, I0, I0, I0, I0),
                                                d0_block_buf,
                                                D0Loader::d0_thread_desc_,
-                                               make_tuple(I0, I0, I0, I0, I0, I0),
+                                               make_tuple(I0, I0, I0, I0, I0),
                                                d0_thread_buf);
+                // static_for<0, 2, 1>{}([&](auto mr) {
                 // bias add
-                static_for<0, D0Loader::d0_thread_desc_.GetElementSpaceSize(), 1>{}(
-                    [&](auto i) {
-                        constexpr index_t c_offset =
-                            c_thread_desc_.CalculateOffset(make_tuple(mr, i));
-                        s_slash_p_thread_buf(Number<c_offset>{}) +=
-                            ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
+                static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
+                    // constexpr index_t c_offset =
+                    //     c_thread_desc_.CalculateOffset(make_tuple(mr, i));
                     // if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
+                    // if(ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]) != 1.0f)
                     // {
-                    //     printf("c_offset: %d s_slash_p_thread_buf(Number<c_offset>{}): %f, "
-                    //            "d0_thread_buf[i]: %f\n",
-                    //            c_offset,
-                    //            s_slash_p_thread_buf(Number<c_offset>{}),
+                    //     float tmp_lds =
+                    //         ck::type_convert<float>(static_cast<D0DataType*>(p_shared)[i.value]);
+                    //     float tmp_gbl = ck::type_convert<float>(p_d0_grid[i.value]);
+                    //     block_sync_lds();
+                    //     printf("id: %d : i: %d, gbl[i]: %f "
+                    //            ",lds[i]: %f d0_thread_buf[i]: %f\n",
+                    //            get_thread_local_1d_id(),
+                    //            i.value,
+                    //            tmp_gbl,
+                    //            tmp_lds,
                     //            ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]));
                     // }
+                    s_slash_p_thread_buf(i) += ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
                 });
-                });
-                d0_block_copy_global_to_lds.MoveSrcSliceWindow(
-                    d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(1, 0, -D0M1.value, 0, 0, 0));
+                //});
+                // d0_block_copy_global_to_lds.MoveSrcSliceWindow(
+                //     d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(1, 0, -D0M1.value, 0, 0,
+                //     0));
             }

             // P_i: = softmax(scalar * S_i:)
...
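This last hunk appears to be the "single block" of the commit title: the static_for over D0M1 sub-blocks is removed, the D0 (bias) tile is loaded into LDS once per iteration, and the bias add collapses to a plain element-wise accumulation over the thread's fragment, with no c_offset remapping. A minimal sketch of the resulting semantics (hypothetical plain-C++ stand-in; the real kernel uses static_for over StaticBuffer elements):

#include <cstddef>
#include <vector>

// Element-wise bias add over one thread's accumulator fragment:
// acc[i] += convert(d0[i]), mirroring
// s_slash_p_thread_buf(i) += ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]).
void bias_add(std::vector<float>& acc, const std::vector<float>& d0)
{
    for(std::size_t i = 0; i < acc.size(); ++i)
        acc[i] += d0[i];
}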