Unverified Commit 408534d4 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Merge branch 'develop' into lwpck-1815

parents a8efb3f0 da214a5a
......@@ -57,8 +57,8 @@ struct DeviceImageToColumnImpl
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>{};
using ConvToGemmFwdTransformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
......@@ -97,9 +97,7 @@ struct DeviceImageToColumnImpl
b_g_k_c_xs_lengths[I2] = C;
c_g_n_k_wos_lengths[I1] = N;
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>(
a_g_n_c_wis_lengths,
ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
image_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
......@@ -108,8 +106,10 @@ struct DeviceImageToColumnImpl
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
N);
input_right_pads};
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......
......@@ -638,6 +638,32 @@ struct AddSilu
}
};
struct ConvScaleAdd
{
__host__ __device__ ConvScaleAdd(float scale_in = 1.f,
float scale_wei = 1.f,
float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
{
}
template <typename E, typename C, typename D>
__host__ __device__ void operator()(E& e, const C& c, const D& d) const;
template <>
__host__ __device__ void
operator()<f8_t, float, float>(f8_t& e, const float& c, const float& d) const
{
float x;
Add{}.template operator()<float>(x, c * scale_in_ * scale_wei_, d);
e = type_convert<f8_t>(x * scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -258,7 +258,7 @@ struct GridwiseMultipleReduction_mk_to_m_multiblock
if(thread_k_cluster_id == 0)
{
if(block_group_size == 0 && !float_equal_zero{}(beta_values[iR]))
if(!float_equal_zero{}(beta_values[iR]))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment