"scripts/convert_original_musicldm_to_diffusers.py" did not exist on "2625fb59dc7a3f03516f0c6c5c0cdda18ef0ef5b"
Commit 3f392b53 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Parameters renaming in batchnorm backward kernels and device op

parent d0b49a14
...@@ -35,8 +35,8 @@ struct DeviceBatchNormBwd : public BaseOperator ...@@ -35,8 +35,8 @@ struct DeviceBatchNormBwd : public BaseOperator
const void* p_savedInvVar, const void* p_savedInvVar,
double epsilon, double epsilon,
void* p_dx, void* p_dx,
void* p_scaleDiff, void* p_dscale,
void* p_biasDiff) = 0; void* p_dbias) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
...@@ -201,8 +201,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -201,8 +201,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
const MeanVarDataType* p_savedInvVar, const MeanVarDataType* p_savedInvVar,
double epsilon, double epsilon,
DxDataType* p_dx, DxDataType* p_dx,
ScaleDataType* p_scaleDiff, ScaleDataType* p_dscale,
BiasDataType* p_biasDiff) BiasDataType* p_dbias)
: bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths), : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides), bnScaleStrides_(bnScaleStrides),
bnBiasStrides_(bnBiasStrides), bnBiasStrides_(bnBiasStrides),
...@@ -213,8 +213,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -213,8 +213,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
p_savedMean_(p_savedMean), p_savedMean_(p_savedMean),
p_savedInvVar_(p_savedInvVar), p_savedInvVar_(p_savedInvVar),
p_dx_(p_dx), p_dx_(p_dx),
p_scaleDiff_(p_scaleDiff), p_dscale_(p_dscale),
p_biasDiff_(p_biasDiff) p_dbias_(p_dbias)
{ {
xyLengths_ = xyLengths_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims); shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims);
...@@ -294,8 +294,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -294,8 +294,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
const MeanVarDataType* p_savedMean_; const MeanVarDataType* p_savedMean_;
const MeanVarDataType* p_savedInvVar_; const MeanVarDataType* p_savedInvVar_;
DxDataType* p_dx_; DxDataType* p_dx_;
ScaleDataType* p_scaleDiff_; ScaleDataType* p_dscale_;
BiasDataType* p_biasDiff_; BiasDataType* p_dbias_;
long_index_t invariant_length; long_index_t invariant_length;
long_index_t reduce_length; long_index_t reduce_length;
...@@ -318,8 +318,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -318,8 +318,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
void* workspace_savedMean; void* workspace_savedMean;
void* workspace_savedInvVar; void* workspace_savedInvVar;
void* workspace_reduce_scale_diff; void* workspace_reduce_dscale;
void* workspace_reduce_bias_diff; void* workspace_reduce_dbias;
}; };
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
...@@ -372,14 +372,14 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -372,14 +372,14 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
index_t space_sz; index_t space_sz;
// setup buffer for the partial reduced result for scale_diff // setup buffer for the partial reduced result for scale_diff
pArg_->workspace_reduce_scale_diff = pArg_->p_workspace_; pArg_->workspace_reduce_dscale = pArg_->p_workspace_;
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType); space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType);
space_sz = math::integer_least_multiple(space_sz, 64); space_sz = math::integer_least_multiple(space_sz, 64);
// setup buffer for the partial reduced result for bias_diff // setup buffer for the partial reduced result for bias_diff
pArg_->workspace_reduce_bias_diff = pArg_->workspace_reduce_dbias =
reinterpret_cast<char*>(pArg_->workspace_reduce_scale_diff) + space_sz; reinterpret_cast<char*>(pArg_->workspace_reduce_dscale) + space_sz;
if(UseMultiblockInK && pArg_->blkGroupSize > 1) if(UseMultiblockInK && pArg_->blkGroupSize > 1)
{ {
...@@ -388,7 +388,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -388,7 +388,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
// setup buffer for welford intermediate mean // setup buffer for welford intermediate mean
pArg_->workspace_mean = pArg_->workspace_mean =
reinterpret_cast<char*>(pArg_->workspace_reduce_bias_diff) + space_sz; reinterpret_cast<char*>(pArg_->workspace_reduce_dbias) + space_sz;
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType); space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType);
space_sz = math::integer_least_multiple(space_sz, 64); space_sz = math::integer_least_multiple(space_sz, 64);
...@@ -604,8 +604,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -604,8 +604,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
: static_cast<MeanVarDataType*>(arg.workspace_savedInvVar), : static_cast<MeanVarDataType*>(arg.workspace_savedInvVar),
arg.p_x_, arg.p_x_,
arg.p_dy_, arg.p_dy_,
static_cast<ScaleDataType*>(arg.workspace_reduce_scale_diff), static_cast<ScaleDataType*>(arg.workspace_reduce_dscale),
static_cast<BiasDataType*>(arg.workspace_reduce_bias_diff)); static_cast<BiasDataType*>(arg.workspace_reduce_dbias));
avg_time += launch_and_time_kernel( avg_time += launch_and_time_kernel(
stream_config, stream_config,
...@@ -624,8 +624,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -624,8 +624,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
arg.reduce_length, arg.reduce_length,
arg.numBlockTileIteration, arg.numBlockTileIteration,
numScaleBiasDiffBlockTileIteration, numScaleBiasDiffBlockTileIteration,
static_cast<const ScaleDataType*>(arg.workspace_reduce_scale_diff), static_cast<const ScaleDataType*>(arg.workspace_reduce_dscale),
static_cast<const BiasDataType*>(arg.workspace_reduce_bias_diff), static_cast<const BiasDataType*>(arg.workspace_reduce_dbias),
arg.haveSavedMeanInvVar_ arg.haveSavedMeanInvVar_
? arg.p_savedMean_ ? arg.p_savedMean_
: static_cast<const MeanVarDataType*>(arg.workspace_savedMean), : static_cast<const MeanVarDataType*>(arg.workspace_savedMean),
...@@ -636,8 +636,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -636,8 +636,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
arg.p_dy_, arg.p_dy_,
arg.p_scale_, arg.p_scale_,
arg.p_dx_, arg.p_dx_,
arg.p_scaleDiff_, arg.p_dscale_,
arg.p_biasDiff_); arg.p_dbias_);
} }
else else
{ {
...@@ -708,8 +708,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -708,8 +708,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
arg.p_savedMean_, arg.p_savedMean_,
arg.p_savedInvVar_, arg.p_savedInvVar_,
arg.p_dx_, arg.p_dx_,
arg.p_scaleDiff_, arg.p_dscale_,
arg.p_biasDiff_); arg.p_dbias_);
}; };
return (avg_time); return (avg_time);
...@@ -801,8 +801,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -801,8 +801,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
const void* p_savedInvVar, const void* p_savedInvVar,
double epsilon, double epsilon,
void* p_dx, void* p_dx,
void* p_scaleDiff, void* p_dscale,
void* p_biasDiff) override void* p_dbias) override
{ {
return std::make_unique<Argument>(xyLengths, return std::make_unique<Argument>(xyLengths,
xStrides, xStrides,
...@@ -820,8 +820,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -820,8 +820,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
static_cast<const MeanVarDataType*>(p_savedInvVar), static_cast<const MeanVarDataType*>(p_savedInvVar),
epsilon, epsilon,
static_cast<DxDataType*>(p_dx), static_cast<DxDataType*>(p_dx),
static_cast<ScaleDataType*>(p_scaleDiff), static_cast<ScaleDataType*>(p_dscale),
static_cast<BiasDataType*>(p_biasDiff)); static_cast<BiasDataType*>(p_dbias));
}; };
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
...@@ -34,16 +34,16 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final( ...@@ -34,16 +34,16 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
long_index_t reduce_size, long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration, index_t num_xy_k_block_tile_iteration,
index_t num_scale_bias_diff_k_block_tile_iteration, index_t num_scale_bias_diff_k_block_tile_iteration,
const ScaleDataType* const __restrict__ p_reduce_scale_diff, const ScaleDataType* const __restrict__ p_reduce_dscale,
const BiasDataType* const __restrict__ p_reduce_bias_diff, const BiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean, const MeanVarDataType* const __restrict__ p_mean,
const MeanVarDataType* const __restrict__ p_inv_var, const MeanVarDataType* const __restrict__ p_inv_var,
const XDataType* const __restrict__ p_x, const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy, const DyDataType* const __restrict__ p_dy,
const ScaleDataType* const __restrict__ p_scale, const ScaleDataType* const __restrict__ p_scale,
DxDataType* const __restrict__ p_dx, DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_scale_diff, ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_bias_diff) BiasDataType* const __restrict__ p_dbias)
{ {
GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k, GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k, dy_grid_desc_m_k,
...@@ -56,16 +56,16 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final( ...@@ -56,16 +56,16 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
reduce_size, reduce_size,
num_xy_k_block_tile_iteration, num_xy_k_block_tile_iteration,
num_scale_bias_diff_k_block_tile_iteration, num_scale_bias_diff_k_block_tile_iteration,
p_reduce_scale_diff, p_reduce_dscale,
p_reduce_bias_diff, p_reduce_dbias,
p_mean, p_mean,
p_inv_var, p_inv_var,
p_x, p_x,
p_dy, p_dy,
p_scale, p_scale,
p_dx, p_dx,
p_scale_diff, p_dscale,
p_bias_diff); p_dbias);
}; };
template <typename XDataType, template <typename XDataType,
...@@ -151,16 +151,16 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -151,16 +151,16 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
long_index_t reduce_size, long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration, index_t num_xy_k_block_tile_iteration,
index_t num_scale_bias_diff_k_block_tile_iteration, index_t num_scale_bias_diff_k_block_tile_iteration,
const ScaleDataType* const __restrict__ p_reduce_scale_diff, const ScaleDataType* const __restrict__ p_reduce_dscale,
const BiasDataType* const __restrict__ p_reduce_bias_diff, const BiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean, const MeanVarDataType* const __restrict__ p_mean,
const MeanVarDataType* const __restrict__ p_inv_var, const MeanVarDataType* const __restrict__ p_inv_var,
const XDataType* const __restrict__ p_x, const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy, const DyDataType* const __restrict__ p_dy,
const ScaleDataType* const __restrict__ p_scale, const ScaleDataType* const __restrict__ p_scale,
DxDataType* const __restrict__ p_dx, DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_scale_diff, ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_bias_diff) BiasDataType* const __restrict__ p_dbias)
{ {
__shared__ AccDataType p_reduce_work_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_buffer[BlockSize];
...@@ -281,16 +281,16 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -281,16 +281,16 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
PassThroughOp{}); PassThroughOp{});
const auto reduce_scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto reduce_scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_scale_diff, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize()); p_reduce_dscale, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize());
const auto reduce_bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto reduce_bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_bias_diff, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize()); p_reduce_dbias, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize());
auto scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_scale_diff, scale_grid_desc_m.GetElementSpaceSize()); p_dscale, scale_grid_desc_m.GetElementSpaceSize());
auto bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_diff, bias_grid_desc_m.GetElementSpaceSize()); p_dbias, bias_grid_desc_m.GetElementSpaceSize());
constexpr auto scale_bias_diff_thread_copy_step_m_k = constexpr auto scale_bias_diff_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize * 1); make_multi_index(0, KThreadClusterSize * 1);
......
...@@ -43,8 +43,8 @@ __global__ void kernel_welford_second_half_reduce_first_half( ...@@ -43,8 +43,8 @@ __global__ void kernel_welford_second_half_reduce_first_half(
MeanVarDataType* const __restrict__ p_out_welford_inv_variance, MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x, const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy, const DyDataType* const __restrict__ p_dy,
ScaleDataType* const __restrict__ p_reduce_scale_diff, ScaleDataType* const __restrict__ p_reduce_dscale,
BiasDataType* const __restrict__ p_reduce_bias_diff) BiasDataType* const __restrict__ p_reduce_dbias)
{ {
GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k, GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k, dy_grid_desc_m_k,
...@@ -65,8 +65,8 @@ __global__ void kernel_welford_second_half_reduce_first_half( ...@@ -65,8 +65,8 @@ __global__ void kernel_welford_second_half_reduce_first_half(
p_out_welford_inv_variance, p_out_welford_inv_variance,
p_x, p_x,
p_dy, p_dy,
p_reduce_scale_diff, p_reduce_dscale,
p_reduce_bias_diff); p_reduce_dbias);
}; };
template <typename XDataType, template <typename XDataType,
...@@ -164,8 +164,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -164,8 +164,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
MeanVarDataType* const __restrict__ p_out_welford_inv_variance, MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x, const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy, const DyDataType* const __restrict__ p_dy,
ScaleDataType* const __restrict__ p_reduce_scale_diff, ScaleDataType* const __restrict__ p_reduce_dscale,
BiasDataType* const __restrict__ p_reduce_bias_diff) BiasDataType* const __restrict__ p_reduce_dbias)
{ {
__shared__ AccDataType p_reduce_work_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_buffer[BlockSize];
...@@ -531,10 +531,10 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -531,10 +531,10 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
PassThroughOp{}); PassThroughOp{});
auto reduce_scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto reduce_scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_scale_diff, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize()); p_reduce_dscale, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize());
auto reduce_bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto reduce_bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_bias_diff, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize()); p_reduce_dbias, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize());
if(thread_k_cluster_id == 0) if(thread_k_cluster_id == 0)
{ {
......
...@@ -45,8 +45,8 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford( ...@@ -45,8 +45,8 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
const MeanVarDataType* const __restrict__ p_savedMean, const MeanVarDataType* const __restrict__ p_savedMean,
const MeanVarDataType* const __restrict__ p_savedInvVar, const MeanVarDataType* const __restrict__ p_savedInvVar,
DxDataType* const __restrict__ p_dx, DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_scale_diff, ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_bias_diff) BiasDataType* const __restrict__ p_dbias)
{ {
GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k, GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k, dy_grid_desc_m_k,
...@@ -65,8 +65,8 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford( ...@@ -65,8 +65,8 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
p_savedMean, p_savedMean,
p_savedInvVar, p_savedInvVar,
p_dx, p_dx,
p_scale_diff, p_dscale,
p_bias_diff); p_dbias);
}; };
template <typename XDataType, template <typename XDataType,
...@@ -166,8 +166,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -166,8 +166,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
const MeanVarDataType* const __restrict__ p_savedMean, const MeanVarDataType* const __restrict__ p_savedMean,
const MeanVarDataType* const __restrict__ p_savedInvVar, const MeanVarDataType* const __restrict__ p_savedInvVar,
DxDataType* const __restrict__ p_dx, DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_scale_diff, ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_bias_diff) BiasDataType* const __restrict__ p_dbias)
{ {
using ck::math::sqrt; using ck::math::sqrt;
...@@ -333,10 +333,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -333,10 +333,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
p_scale, scale_grid_desc_m.GetElementSpaceSize()); p_scale, scale_grid_desc_m.GetElementSpaceSize());
auto scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_scale_diff, scale_grid_desc_m.GetElementSpaceSize()); p_dscale, scale_grid_desc_m.GetElementSpaceSize());
auto bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_diff, bias_grid_desc_m.GetElementSpaceSize()); p_dbias, bias_grid_desc_m.GetElementSpaceSize());
if(haveSavedMeanInvVar) if(haveSavedMeanInvVar)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment