Commit 67ab3896 authored by Aleksander Dudek's avatar Aleksander Dudek
Browse files

Merge branch 'develop' into gemm_getname

parents 8adaf418 d5c8a334
......@@ -77,6 +77,43 @@ struct DeviceGemmV2R1 : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename BScaleType,
typename CDataType,
index_t ScaleBlockN,
index_t ScaleBlockK,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmV2BScale : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t StrideScaleB,
const void* p_b_scale,
ck::index_t KSplit,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual bool GetPermuteB() = 0;
virtual ck::index_t GetKPerBlock() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -469,7 +469,11 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
{
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> &&
arg.Streamk_sel > 0)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment