Commit a59e8d48 authored by ltqin's avatar ltqin
Browse files

Merge branch 'develop' into mha-train-develop

parents f0984ca5 7761e523
...@@ -70,10 +70,9 @@ __global__ void ...@@ -70,10 +70,9 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
// TODO: Enable for gfx90a after complier fix #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \ defined(__gfx1101__) || defined(__gfx1102__))
defined(__gfx1102__))
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -650,10 +649,10 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -650,10 +649,10 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
// TODO: Enable for gfx90a after complier fix // TODO: Enable for gfx90a after complier fix
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" || if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" ||
ck::get_device_name() == "gfx1102") ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")
{ {
bool pass = true; bool pass = true;
pass = pass && arg.K_ % K1 == 0; pass = pass && arg.K_ % K1 == 0;
......
...@@ -809,9 +809,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -809,9 +809,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || if(!ck::is_xdl_supported())
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
...@@ -801,6 +801,11 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO ...@@ -801,6 +801,11 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
......
...@@ -727,9 +727,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -727,9 +727,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
arg.Print(); arg.Print();
#endif #endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || if(!ck::is_xdl_supported())
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
...@@ -613,9 +613,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -613,9 +613,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || if(!ck::is_xdl_supported())
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
...@@ -310,6 +310,11 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -310,6 +310,11 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
static bool IsSupportedArgument(const Problem& problem) static bool IsSupportedArgument(const Problem& problem)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
if(problem.K % K1 != 0) if(problem.K % K1 != 0)
{ {
return false; return false;
......
...@@ -448,6 +448,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -448,6 +448,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
return GridwiseGemm::CheckValidity(arg); return GridwiseGemm::CheckValidity(arg);
} }
......
...@@ -582,9 +582,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -582,9 +582,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || if(!ck::is_xdl_supported())
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
...@@ -649,6 +649,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -649,6 +649,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
// vector load A/B matrix from global memory // vector load A/B matrix from global memory
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
......
...@@ -616,6 +616,11 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -616,6 +616,11 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
if constexpr(ConvBackwardDataSpecialization == if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{ {
......
...@@ -810,6 +810,11 @@ struct ...@@ -810,6 +810,11 @@ struct
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
......
...@@ -767,6 +767,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -767,6 +767,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
......
...@@ -741,6 +741,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -741,6 +741,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
......
...@@ -524,6 +524,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -524,6 +524,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{ {
......
...@@ -524,6 +524,11 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -524,6 +524,11 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
......
...@@ -1320,6 +1320,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1320,6 +1320,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
if constexpr(ConvBackwardDataSpecialization == if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{ {
......
...@@ -683,6 +683,11 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO ...@@ -683,6 +683,11 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
......
...@@ -855,9 +855,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -855,9 +855,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || if(!ck::is_xdl_supported())
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
...@@ -555,9 +555,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -555,9 +555,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || if(!ck::is_xdl_supported())
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
...@@ -491,9 +491,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -491,9 +491,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || if(!ck::is_xdl_supported())
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{ {
return false; return false;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment