"vscode:/vscode.git/clone" did not exist on "2ce80194806f73c1b7ced1d686ce01efd3aefdc7"
Commit 1970d162 authored by fsx950223
Browse files

Merge remote-tracking branch 'origin/attn-train-develop-qloop' into skip_dropout

parents 5a3904c7 9b4c780a
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -81,7 +81,8 @@ __global__ void ...@@ -81,7 +81,8 @@ __global__ void
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -449,7 +450,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -449,7 +450,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle< using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
ZDataType, ZDataType,
GemmDataType, GemmDataType,
...@@ -898,7 +899,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -898,7 +899,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
arg.Print(); arg.Print();
#endif #endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -83,7 +83,8 @@ __global__ void ...@@ -83,7 +83,8 @@ __global__ void
const index_t raw_m_padded, const index_t raw_m_padded,
const index_t raw_n_padded) const index_t raw_n_padded)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -457,7 +458,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -457,7 +458,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle< using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
ZDataType, ZDataType,
GemmDataType, GemmDataType,
...@@ -915,7 +916,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -915,7 +916,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
arg.Print(); arg.Print();
#endif #endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -52,7 +52,8 @@ __global__ void ...@@ -52,7 +52,8 @@ __global__ void
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>( const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
...@@ -576,7 +577,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1 ...@@ -576,7 +577,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
InputDataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType, OutputDataType,
ZDataType, ZDataType,
...@@ -1020,7 +1021,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1 ...@@ -1020,7 +1021,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -52,7 +52,8 @@ __global__ void ...@@ -52,7 +52,8 @@ __global__ void
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>( const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
...@@ -569,7 +570,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2 ...@@ -569,7 +570,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
InputDataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType, OutputDataType,
ZDataType, ZDataType,
...@@ -1012,7 +1013,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2 ...@@ -1012,7 +1013,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -53,7 +53,8 @@ __global__ void ...@@ -53,7 +53,8 @@ __global__ void
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>( const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
...@@ -574,7 +575,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -574,7 +575,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1<
InputDataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType, OutputDataType,
ZDataType, ZDataType,
...@@ -1033,7 +1034,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1033,7 +1034,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -53,7 +53,8 @@ __global__ void ...@@ -53,7 +53,8 @@ __global__ void
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>( const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
...@@ -574,7 +575,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -574,7 +575,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2<
InputDataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType, OutputDataType,
ZDataType, ZDataType,
...@@ -1040,7 +1041,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1040,7 +1041,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -52,7 +52,8 @@ __global__ void ...@@ -52,7 +52,8 @@ __global__ void
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
...@@ -465,7 +466,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -465,7 +466,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle< using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
ZDataType, ZDataType,
GemmDataType, GemmDataType,
...@@ -938,7 +939,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -938,7 +939,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -52,7 +52,8 @@ __global__ void ...@@ -52,7 +52,8 @@ __global__ void
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
...@@ -471,7 +472,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -471,7 +472,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle< using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
ZDataType, ZDataType,
GemmDataType, GemmDataType,
...@@ -960,7 +961,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -960,7 +961,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -88,7 +88,7 @@ template <typename InputDataType, ...@@ -88,7 +88,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
bool Deterministic, bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
...@@ -1943,8 +1943,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1943,8 +1943,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
if(c0_matrix_mask.IsMaskedElement(m_global, n_global)) if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
{ {
s_slash_p_thread_buf(i) = -ck::NumericLimits<float>::Infinity(); s_slash_p_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
...@@ -2086,6 +2086,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -2086,6 +2086,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
p_slice_idx[I3], p_slice_idx[I3],
p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3)); p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range)) if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{ {
vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run( vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
...@@ -2096,8 +2097,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -2096,8 +2097,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
gemm2_a_block_buf); gemm2_a_block_buf);
} }
// block_sync_lds(); // sync before write
vgrad_gemm_tile_ygrad_blockwise_copy.Run(Gemm2::b_block_desc_o0_o1_o2_m0_m1_m2_m3, vgrad_gemm_tile_ygrad_blockwise_copy.Run(Gemm2::b_block_desc_o0_o1_o2_m0_m1_m2_m3,
ygrad_block_buf, ygrad_block_buf,
Gemm2::b_thread_desc_o0_o1_o2_m0_m1_m2_m3, Gemm2::b_thread_desc_o0_o1_o2_m0_m1_m2_m3,
...@@ -2135,6 +2134,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -2135,6 +2134,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
sgrad_slice_idx[I3] + sgrad_slice_idx[I3] +
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3)); Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range)) if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{ {
kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run( kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -88,7 +88,7 @@ template <typename InputDataType, ...@@ -88,7 +88,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
bool Deterministic, bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
...@@ -1448,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1448,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
1, 1,
false>{ false>{
lse_grid_desc_mblock_mrepeat_mwave_mperxdl, lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(block_work_idx_m, // mblock make_multi_index(block_work_idx_m, // mblock
acc0_thread_origin[I0], // mrepeat acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl acc0_thread_origin[I4])}; // mperxdl
...@@ -1511,14 +1511,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1511,14 +1511,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId 0, // NBlockId
0, // mrepeat 0, // mrepeat
0, // nrepeat 0, // nrepeat
wave_id[I0], // MWaveId wave_id[I0], // MWaveId
wave_id[I1], // NWaveId wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl wave_m_n_id[I1], // MPerXdl
0, // group 0, // group
wave_m_n_id[I0], // NInputIndex wave_m_n_id[I0], // NInputIndex
0), 0),
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
...@@ -1838,16 +1838,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1838,16 +1838,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto n_global = n_local + n_block_data_idx_on_grid;
if(c0_matrix_mask.IsMaskedElement(m_global, n_global)) bool masked_flag = c0_matrix_mask.IsMaskedElement(m_global, n_global);
{ s_element_op(s_slash_p_thread_buf(i),
s_slash_p_thread_buf(i) = -ck::NumericLimits<float>::Infinity(); masked_flag ? -ck::NumericLimits<float>::Infinity()
} : s_slash_p_thread_buf[i]);
else
{
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
}
}); });
} }
else else
...@@ -1924,6 +1920,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1924,6 +1920,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
p_slice_idx[I3], p_slice_idx[I3],
p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3)); p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range)) if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{ {
vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run( vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
...@@ -1939,7 +1936,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1939,7 +1936,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow( vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_m0_o_m1, Gemm2::b_block_slice_copy_step); ygrad_grid_desc_m0_o_m1, Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before write
vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1, vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
gemm2_b_block_buf); gemm2_b_block_buf);
...@@ -1987,17 +1983,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1987,17 +1983,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
constexpr auto m = constexpr auto m =
pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0]; pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
// dS and P has same thread buf layout // dS and P has same thread buf layout
if(s_slash_p_thread_buf[i] >= 0) bool undropped_flag = s_slash_p_thread_buf[i] >= 0;
{ sgrad_thread_buf(i) =
sgrad_thread_buf(i) = s_slash_p_thread_buf[i] *
s_slash_p_thread_buf[i] * (undropped_flag ? (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}])
(pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]); : y_dot_ygrad_thread_buf[Number<m>{}]);
}
else
{
sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}];
}
}); });
// gemm dQ // gemm dQ
...@@ -2082,6 +2072,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -2082,6 +2072,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
sgrad_slice_idx[I3] + sgrad_slice_idx[I3] +
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3)); Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range)) if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{ {
kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run( kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
...@@ -2098,7 +2089,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -2098,7 +2089,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1, kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
Gemm2::b_block_slice_copy_step); Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before write
kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1, kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
gemm2_b_block_buf); gemm2_b_block_buf);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -81,7 +81,7 @@ template <typename InputDataType, ...@@ -81,7 +81,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
bool Deterministic, bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -89,7 +89,7 @@ template <typename InputDataType, ...@@ -89,7 +89,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
bool Deterministic, bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -89,7 +89,7 @@ template <typename InputDataType, ...@@ -89,7 +89,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
bool Deterministic, bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -81,7 +81,7 @@ template <typename InputDataType, ...@@ -81,7 +81,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
bool Deterministic, bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -89,7 +89,7 @@ template <typename InputDataType, ...@@ -89,7 +89,7 @@ template <typename InputDataType,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
bool Deterministic, bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -87,7 +87,7 @@ template <typename FloatAB, ...@@ -87,7 +87,7 @@ template <typename FloatAB,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
bool Deterministic, bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -87,7 +87,7 @@ template <typename FloatAB, ...@@ -87,7 +87,7 @@ template <typename FloatAB,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
bool Deterministic, bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
......
...@@ -219,7 +219,7 @@ typename std::enable_if< ...@@ -219,7 +219,7 @@ typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> && std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, unsigned short>, std::is_same_v<ranges::range_value_t<Range>, unsigned short>,
bool>::type bool>::type
check_err(const Range& out, const RefRange& ref, unsigned short atol = 1) check_integer_err(const Range& out, const RefRange& ref, unsigned short atol)
{ {
const std::string& msg = "Error: Incorrect U16 results!"; const std::string& msg = "Error: Incorrect U16 results!";
if(out.size() != ref.size()) if(out.size() != ref.size())
...@@ -262,7 +262,7 @@ typename std::enable_if< ...@@ -262,7 +262,7 @@ typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> && std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, int32_t>, std::is_same_v<ranges::range_value_t<Range>, int32_t>,
bool>::type bool>::type
check_err(const Range& out, const RefRange& ref, int32_t atol = 1) check_integer_err(const Range& out, const RefRange& ref, int32_t atol)
{ {
const std::string& msg = "Error: Incorrect U16 results!"; const std::string& msg = "Error: Incorrect U16 results!";
if(out.size() != ref.size()) if(out.size() != ref.size())
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment