Commit f90af872 authored by letaoqin

fix merge issue

parent 63e3f3c4
@@ -24,7 +24,7 @@ Kernel outputs:
 */
 #define USING_MASK 0
-#define DIM 64 // DIM should be a multiple of 8.
+#define DIM 128 // DIM should be a multiple of 8.
 #include <iostream>
 #include <numeric>
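The comment on DIM states a divisibility constraint that is easy to violate when the define is edited by hand, as it is in this hunk. A minimal sketch of enforcing that rule at compile time; the static_assert guard is an illustration and not part of this commit:

```cpp
// Hedged sketch: turn the "multiple of 8" comment on DIM into a compile-time
// check. The guard itself is an assumption for illustration, not commit code.
#define DIM 128 // DIM should be a multiple of 8.

static_assert(DIM % 8 == 0, "DIM must be a multiple of 8");

int main() { return 0; }
```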
@@ -1332,7 +1332,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
 static constexpr auto d0_thread_desc_ =
 make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
-using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
+using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
 ThisThreadBlock,
 tensor_operation::element_wise::PassThrough,
 tensor_operation::element_wise::PassThrough,
@@ -1361,7 +1361,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
 true, // DstResetCoord
 1>;
-using D0ThreadWiseCopy =
+using D0ThreadwiseCopyLdsToVgpr =
 ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
 typename TypeTransform<D0DataType>::Type, // DstData
 decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc
@@ -1372,7 +1372,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
 2, // SrcScalarPerVector
 2>;
-using D0ThreadCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
+using D0ThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
 FloatGemmAcc,
 typename TypeTransform<D0DataType>::Type,
 decltype(d0_thread_desc_),
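The three renames above encode the transfer direction (global to LDS, LDS to VGPR, VGPR to LDS) in the alias name instead of leaving it implicit. When a rename like this lands while other branches still reference the old name, a deprecated alias can keep stale call sites compiling during the transition; a minimal sketch, with ThreadGroupCopy standing in for the real ThreadGroupTensorSliceTransfer_v4r1 instantiation, and the deprecated alias being a hypothetical transition aid rather than something this commit does:

```cpp
#include <type_traits>

// Stand-in for the real ThreadGroupTensorSliceTransfer_v4r1<...> instantiation.
struct ThreadGroupCopy {};

// New, direction-bearing name, mirroring the rename in this commit.
using D0BlockwiseCopyGlobalToLds = ThreadGroupCopy;

// Hypothetical transition aid (not in this commit): old name keeps compiling
// but emits a deprecation warning wherever it is still used.
using D0BlockwiseCopy [[deprecated("renamed to D0BlockwiseCopyGlobalToLds")]] =
    D0BlockwiseCopyGlobalToLds;

static_assert(std::is_same_v<D0BlockwiseCopyGlobalToLds, ThreadGroupCopy>);

int main() { return 0; }
```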
@@ -2054,7 +2054,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
 // gemm0 M loop
 index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
 // D0
-auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopy(
+auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
 d0_grid_desc_m0_n0_m1_m2_n1_m3,
 make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
 tensor_operation::element_wise::PassThrough{},
@@ -2062,10 +2062,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
 make_multi_index(0, 0, 0, 0, 0, 0),
 tensor_operation::element_wise::PassThrough{});
-auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadCopy(
+auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
 make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
-auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadCopyVgprToLds(
+auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadwiseCopyVgprToLds(
 D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
 make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
 tensor_operation::element_wise::Scale{rp_dropout});
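Taken together, the three D0 copies instantiated here form a staging pipeline: a blockwise copy from global memory into LDS, a threadwise copy from LDS into registers, and a threadwise copy of the scaled result back into LDS. A host-side sketch of that data flow, with plain arrays standing in for the three address spaces; this is purely illustrative, not CK code, and 0.5f is an arbitrary placeholder for the Scale{rp_dropout} factor:

```cpp
#include <algorithm>
#include <array>

int main()
{
    std::array<float, 8> global{1, 2, 3, 4, 5, 6, 7, 8}; // "global memory"
    std::array<float, 8> lds{};                          // "LDS tile"
    std::array<float, 8> vgpr{};                         // "per-thread registers"

    // blockwise global -> LDS (role of D0BlockwiseCopyGlobalToLds)
    std::copy(global.begin(), global.end(), lds.begin());
    // threadwise LDS -> VGPR (role of D0ThreadwiseCopyLdsToVgpr)
    std::copy(lds.begin(), lds.end(), vgpr.begin());

    // elementwise scale on the result, as Scale{rp_dropout} does above
    for(auto& v : vgpr)
        v *= 0.5f;

    // threadwise VGPR -> LDS (role of D0ThreadwiseCopyVgprToLds)
    std::copy(vgpr.begin(), vgpr.end(), lds.begin());
    return 0;
}
```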
@@ -1388,7 +1388,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
 static constexpr auto d0_thread_desc_ =
 make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
-using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
+using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
 ThisThreadBlock,
 tensor_operation::element_wise::PassThrough,
 tensor_operation::element_wise::PassThrough,
@@ -1417,7 +1417,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
 true, // DstResetCoord
 1>;
-using D0ThreadWiseCopy =
+using D0ThreadwiseCopyLdsToVgpr =
 ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
 typename TypeTransform<D0DataType>::Type, // DstData
 decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc
@@ -1511,7 +1511,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
 sizeof(GemmDataType) / sizeof(FloatGemmAcc);
 static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
-D0Operator::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
+D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(),
+max_lds_align);
 static constexpr auto d0_block_space_offset =
 k_block_space_size_aligned.value * sizeof(GemmDataType) /
 D0Operator::template TypeTransform<D0DataType>::Size;
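The LDS sizing above rounds the descriptor's element space up to max_lds_align. A self-contained sketch of that round-up-to-multiple computation; the function body is assumed from the name math::integer_least_multiple and is not copied from CK:

```cpp
#include <cstddef>

// Assumed semantics of math::integer_least_multiple(x, a): the smallest
// multiple of a that is >= x (for a > 0).
constexpr std::size_t integer_least_multiple(std::size_t x, std::size_t a)
{
    return ((x + a - 1) / a) * a;
}

static_assert(integer_least_multiple(100, 8) == 104); // rounds up
static_assert(integer_least_multiple(128, 8) == 128); // already aligned

int main() { return 0; }
```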
@@ -2126,7 +2127,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
 index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
 // D0
-auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopy(
+auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
 d0_grid_desc_m0_n0_m1_m2_n1_m3,
 make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
 tensor_operation::element_wise::PassThrough{},
@@ -2134,7 +2135,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
 make_multi_index(0, 0, 0, 0, 0, 0),
 tensor_operation::element_wise::PassThrough{});
-auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadCopy(
+auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
 make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
-auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadCopyVgprToLds(
+auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadwiseCopyVgprToLds(