Commit 3552041a authored by danyao12's avatar danyao12
Browse files

Merge branch 'develop' into ck_tile/fa_bwd_opt

parents e8927110 733f33af
...@@ -290,7 +290,9 @@ class FmhaBwdApiPool: ...@@ -290,7 +290,9 @@ class FmhaBwdApiPool:
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# per_dtypes is empty: emit (void) casts of t, s and a to suppress warnings in the generated API
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
# GEMM0: Q@K=S^T # GEMM0: Q@K=S^T
......
...@@ -279,6 +279,9 @@ class FmhaFwdApiPool: ...@@ -279,6 +279,9 @@ class FmhaFwdApiPool:
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# per_dtypes is empty: emit (void) casts of t, s and a to suppress warnings in the generated API
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
@dataclass @dataclass
......
...@@ -332,6 +332,9 @@ class FmhaFwdSplitKVApiPool: ...@@ -332,6 +332,9 @@ class FmhaFwdSplitKVApiPool:
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# per_dtypes is empty: emit (void) casts of t, s and a to suppress warnings in the generated API
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes) return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes)
@dataclass @dataclass
......
File mode changed from 100644 to 100755
...@@ -8,7 +8,7 @@ export CK_WARMUP=0 ...@@ -8,7 +8,7 @@ export CK_WARMUP=0
export CK_REPEAT=1 export CK_REPEAT=1
COMMON_ARGS='-v=1' COMMON_ARGS='-v=1'
set -x
for prec in "fp16" "bf16" ; do for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do for perm in 0 1 ; do
for hdim in 32 64 128 256 ; do for hdim in 32 64 128 256 ; do
...@@ -33,3 +33,4 @@ done ...@@ -33,3 +33,4 @@ done
done done
done done
done done
set +x
...@@ -10,7 +10,7 @@ export CK_REPEAT=1 ...@@ -10,7 +10,7 @@ export CK_REPEAT=1
COMMON_ARGS='-v=1 -warmup=0 -repeat=1' COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
# mode=0 # mode=0
# export HIP_VISIBLE_DEVICES=4 # export HIP_VISIBLE_DEVICES=4
set -x
for prec in "fp16" "bf16" ; do for prec in "fp16" "bf16" ; do
for mode in 1 0 ; do for mode in 1 0 ; do
for perm in 0 1 ; do for perm in 0 1 ; do
...@@ -40,6 +40,7 @@ done ...@@ -40,6 +40,7 @@ done
done done
done done
for perm in 0 1 ; do for perm in 0 1 ; do
for bias in "n" "e" "a" ; do for bias in "n" "e" "a" ; do
for b in 1 2 ; do for b in 1 2 ; do
...@@ -49,3 +50,4 @@ done ...@@ -49,3 +50,4 @@ done
done done
done done
done done
set +x
This diff is collapsed.
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -15,6 +15,7 @@ enum struct ConvolutionForwardSpecialization ...@@ -15,6 +15,7 @@ enum struct ConvolutionForwardSpecialization
Filter1x1Pad0, Filter1x1Pad0,
Filter1x1Stride1Pad0, Filter1x1Stride1Pad0,
OddC, OddC,
Filter3x3,
}; };
inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s) inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s)
...@@ -25,6 +26,7 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp ...@@ -25,6 +26,7 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp
case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0"; case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
case ConvolutionForwardSpecialization::OddC: return "OddC"; case ConvolutionForwardSpecialization::OddC: return "OddC";
case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3";
default: return "Unrecognized specialization!"; default: return "Unrecognized specialization!";
} }
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// GEMM:
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
//
// Abstract device-operation interface for the GEMM described above, with two
// additional scale inputs (p_a_scale, p_b_scale) passed alongside A and B.
// NOTE(review): this header does not show how the scales enter the formula
// above; presumably they are applied to A and B before/within the product -
// confirm against a concrete implementation.
//
// Template parameters:
//   ALayout / BLayout / DsLayout / ELayout : layout tags for A, B, the D tensors and E.
//   ADataType / BDataType / EDataType      : element types of A, B and E.
//   AScaleType / BScaleType                : element types of the A and B scale inputs.
//   DsDataType                             : type list of the D tensors; must provide a
//                                            static Size() (used for NumDTensor below).
//   ScaleBlockM / ScaleBlockN / ScaleBlockK: presumably the block extents along M/N/K
//                                            that share one scale value - TODO confirm.
//   AElementwiseOperation / BElementwiseOperation / CDEElementwiseOperation
//                                          : functor types for a_op, b_op and cde_op.
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename AScaleType,
typename BDataType,
typename BScaleType,
typename DsDataType,
typename EDataType,
index_t ScaleBlockM,
index_t ScaleBlockN,
index_t ScaleBlockK,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleD_ABScale : public BaseOperator
{
// Number of D tensors, taken from the DsDataType type list.
static constexpr index_t NumDTensor = DsDataType::Size();
// Pure virtual: implementations return an owning pointer to a BaseArgument
// created from the problem description below.
//   p_a, p_b             : type-erased pointers for the A and B inputs.
//   p_ds                 : one type-erased pointer per D tensor (NumDTensor entries).
//   p_e                  : type-erased pointer for the E output (non-const, i.e. writable).
//   M, N, K              : GEMM extents as used in the shape comment above the struct.
//   StrideA, StrideB     : strides for A and B (presumably leading-dimension strides - confirm).
//   StrideDs             : one stride per D tensor (NumDTensor entries).
//   StrideE              : stride for E.
//   p_a_scale, p_b_scale : type-erased pointers for the A and B scale inputs
//                          (element types presumably AScaleType / BScaleType - confirm).
//   a_element_op, b_element_op, cde_element_op
//                        : the a_op, b_op and cde_op functor instances, passed by value.
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const ck::index_t StrideA,
const ck::index_t StrideB,
const std::array<ck::index_t, NumDTensor> StrideDs,
const ck::index_t StrideE,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
// Pure virtual: implementations return an owning pointer to a BaseInvoker
// (presumably the object that launches the operation with an argument made
// by MakeArgumentPointer - confirm against BaseInvoker's declaration).
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment