Commit 9526b9ec authored by Chao Liu's avatar Chao Liu
Browse files

clean

parent 12585e57
...@@ -39,7 +39,7 @@ struct DeviceConvFwdMultipleD : public BaseOperator ...@@ -39,7 +39,7 @@ struct DeviceConvFwdMultipleD : public BaseOperator
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer( virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a, const void* p_a,
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, const std::array<const void*, NumDTensor>& p_ds,
void* p_e, void* p_e,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides, const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
......
...@@ -716,7 +716,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -716,7 +716,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
Argument( Argument(
const void* p_a, const void* p_a,
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, const std::array<const void*, NumDTensor>& p_ds,
void* p_e, void* p_e,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides, const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
...@@ -910,18 +910,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -910,18 +910,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
arg.block_2_etile_map_); arg.block_2_etile_map_);
}; };
float avg_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
avg_time = launch_kernel(integral_constant<bool, true>{}); return launch_kernel(integral_constant<bool, true>{});
} }
else else
{ {
avg_time = launch_kernel(integral_constant<bool, false>{}); return launch_kernel(integral_constant<bool, false>{});
} }
return avg_time;
} }
float Run(const BaseArgument* p_arg, float Run(const BaseArgument* p_arg,
...@@ -935,6 +931,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -935,6 +931,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
{ {
namespace ctc = tensor_layout::convolution; namespace ctc = tensor_layout::convolution;
// check device
if(get_device_name() == "gfx908") if(get_device_name() == "gfx908")
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> || if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
...@@ -956,12 +953,12 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -956,12 +953,12 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return false; return false;
} }
// check tensor size: can't be larger than 2GB each // check tensor size: cannot be larger than 2GB each
constexpr long_index_t GB2 = (long_index_t{1} << 31); constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(arg.a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 || if(arg.a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() * sizeof(ADataType) > TwoGB ||
arg.b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 || arg.b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * sizeof(BDataType) > TwoGB ||
arg.e_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) > GB2) arg.e_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) > TwoGB)
{ {
return false; return false;
} }
...@@ -1066,7 +1063,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1066,7 +1063,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static auto MakeArgument( static auto MakeArgument(
const void* p_a, const void* p_a,
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, const std::array<const void*, NumDTensor>& p_ds,
void* p_e, void* p_e,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides, const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
...@@ -1110,7 +1107,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1110,7 +1107,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
std::unique_ptr<BaseArgument> MakeArgumentPointer( std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a, const void* p_a,
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, const std::array<const void*, NumDTensor>& p_ds,
void* p_e, void* p_e,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides, const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment