Commit 1970d162 authored by fsx950223
Browse files

Merge remote-tracking branch 'origin/attn-train-develop-qloop' into skip_dropout

parents 5a3904c7 9b4c780a
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as: Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as: Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as: Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
...@@ -1396,7 +1396,7 @@ int run(int argc, char* argv[]) ...@@ -1396,7 +1396,7 @@ int run(int argc, char* argv[])
} }
std::cout << "Checking z:\n"; std::cout << "Checking z:\n";
pass &= ck::utils::check_err(z_fwd_gs_ms_ns.mData, z_bwd_gs_ms_ns.mData, 1); pass &= ck::utils::check_integer_err(z_fwd_gs_ms_ns.mData, z_bwd_gs_ms_ns.mData, 1);
std::cout << "Checking y:\n"; std::cout << "Checking y:\n";
pass &= ck::utils::check_err( pass &= ck::utils::check_err(
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
...@@ -969,7 +969,7 @@ int run(int argc, char* argv[]) ...@@ -969,7 +969,7 @@ int run(int argc, char* argv[])
} }
std::cout << "Checking z:\n"; std::cout << "Checking z:\n";
pass &= ck::utils::check_err(z_fwd_gs_ms_ns.mData, z_bwd_gs_ms_ns.mData, 1); pass &= ck::utils::check_integer_err(z_fwd_gs_ms_ns.mData, z_bwd_gs_ms_ns.mData, 1);
std::cout << "Checking y:\n"; std::cout << "Checking y:\n";
pass &= ck::utils::check_err( pass &= ck::utils::check_err(
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as: Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as: Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
...@@ -1420,7 +1420,7 @@ int run(int argc, char* argv[]) ...@@ -1420,7 +1420,7 @@ int run(int argc, char* argv[])
} }
std::cout << "Checking z:\n"; std::cout << "Checking z:\n";
pass &= ck::utils::check_err(z_fwd_tensors[i].mData, z_bwd_tensors[i].mData, 1); pass &= ck::utils::check_integer_err(z_fwd_tensors[i].mData, z_bwd_tensors[i].mData, 1);
std::cout << "Checking y:\n"; std::cout << "Checking y:\n";
pass &= ck::utils::check_err( pass &= ck::utils::check_err(
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
...@@ -994,7 +994,7 @@ int run(int argc, char* argv[]) ...@@ -994,7 +994,7 @@ int run(int argc, char* argv[])
} }
std::cout << "Checking z:\n"; std::cout << "Checking z:\n";
pass &= ck::utils::check_err(z_fwd_tensors[i].mData, z_bwd_tensors[i].mData, 1); pass &= ck::utils::check_integer_err(z_fwd_tensors[i].mData, z_bwd_tensors[i].mData, 1);
std::cout << "Checking y:\n"; std::cout << "Checking y:\n";
pass &= ck::utils::check_err( pass &= ck::utils::check_err(
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[]) int run(int argc, char* argv[])
{ {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[]) int run(int argc, char* argv[])
{ {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -89,7 +89,8 @@ __global__ void ...@@ -89,7 +89,8 @@ __global__ void
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -641,7 +642,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1 ...@@ -641,7 +642,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
InputDataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType, OutputDataType,
ZDataType, ZDataType,
...@@ -1050,7 +1051,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1 ...@@ -1050,7 +1051,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
arg.Print(); arg.Print();
#endif #endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -88,7 +88,8 @@ __global__ void ...@@ -88,7 +88,8 @@ __global__ void
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -640,7 +641,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2 ...@@ -640,7 +641,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
InputDataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType, OutputDataType,
ZDataType, ZDataType,
...@@ -1051,7 +1052,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2 ...@@ -1051,7 +1052,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
arg.Print(); arg.Print();
#endif #endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -87,7 +87,8 @@ __global__ void ...@@ -87,7 +87,8 @@ __global__ void
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -637,7 +638,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1 ...@@ -637,7 +638,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1<
InputDataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType, OutputDataType,
ZDataType, ZDataType,
...@@ -1044,7 +1045,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1 ...@@ -1044,7 +1045,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
arg.Print(); arg.Print();
#endif #endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -90,7 +90,8 @@ __global__ void ...@@ -90,7 +90,8 @@ __global__ void
const index_t raw_m_padded, const index_t raw_m_padded,
const index_t raw_n_padded) const index_t raw_n_padded)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -628,7 +629,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -628,7 +629,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1<
InputDataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType, OutputDataType,
ZDataType, ZDataType,
...@@ -1024,7 +1025,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1024,7 +1025,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -89,7 +89,8 @@ __global__ void ...@@ -89,7 +89,8 @@ __global__ void
const index_t raw_m_padded, const index_t raw_m_padded,
const index_t raw_n_padded) const index_t raw_n_padded)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -634,7 +635,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -634,7 +635,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2<
InputDataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType, OutputDataType,
ZDataType, ZDataType,
...@@ -1055,7 +1056,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1055,7 +1056,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment