Commit 9526b9ec authored by Chao Liu's avatar Chao Liu
Browse files

clean

parent 12585e57
......@@ -39,7 +39,7 @@ struct DeviceConvFwdMultipleD : public BaseOperator
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
......
......@@ -716,7 +716,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
Argument(
const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
......@@ -910,18 +910,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
arg.block_2_etile_map_);
};
float avg_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
avg_time = launch_kernel(integral_constant<bool, true>{});
return launch_kernel(integral_constant<bool, true>{});
}
else
{
avg_time = launch_kernel(integral_constant<bool, false>{});
return launch_kernel(integral_constant<bool, false>{});
}
return avg_time;
}
float Run(const BaseArgument* p_arg,
......@@ -935,6 +931,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
{
namespace ctc = tensor_layout::convolution;
// check device
if(get_device_name() == "gfx908")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
......@@ -956,12 +953,12 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false;
}
// check tensor size: can't be larger than 2GB each
constexpr long_index_t GB2 = (long_index_t{1} << 31);
// check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(arg.a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 ||
arg.b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 ||
arg.e_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) > GB2)
if(arg.a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() * sizeof(ADataType) > TwoGB ||
arg.b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * sizeof(BDataType) > TwoGB ||
arg.e_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) > TwoGB)
{
return false;
}
......@@ -1066,7 +1063,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static auto MakeArgument(
const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
......@@ -1110,7 +1107,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment