Commit 548ba6f4 authored by Chao Liu's avatar Chao Liu
Browse files

use scale

parent c961ce92
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
......@@ -51,7 +51,7 @@ using CLayout = Row;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
......@@ -122,7 +122,7 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
AccDataType,
AElementOp,
B0ElementOp,
CElementOp>;
Acc0ElementOp>;
// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
......@@ -157,6 +157,7 @@ int main(int argc, char* argv[])
ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideC = -1;
float alpha = 1;
if(argc == 1)
{
......@@ -181,7 +182,7 @@ int main(int argc, char* argv[])
BatchCount = std::stoi(argv[8]);
}
else if(argc == 17)
else if(argc == 18)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
......@@ -203,6 +204,8 @@ int main(int argc, char* argv[])
BatchStrideB0 = std::stoi(argv[14]);
BatchStrideB1 = std::stoi(argv[15]);
BatchStrideC = std::stoi(argv[16]);
alpha = std::stof(argv[17]);
}
else
{
......@@ -211,6 +214,7 @@ int main(int argc, char* argv[])
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
"BatchStrideB0, BatchStrideB1, BatchStrideC\n");
printf("arg18: alpha\n");
exit(0);
}
......@@ -304,7 +308,7 @@ int main(int argc, char* argv[])
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{};
auto acc0_element_op = Acc0ElementOp{alpha};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
......@@ -368,7 +372,7 @@ int main(int argc, char* argv[])
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, PassThrough{});
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument);
......
# TODO: add example batched_gemm_gemm_xdl_fp16
add_example_executable(example_batched_gemm_softmax_gemm_xdl_fp16 batched_gemm_softmax_gemm_xdl_fp16.cpp)
......@@ -46,7 +46,7 @@ add_subdirectory(28_grouped_gemm_bias_e_permute)
add_subdirectory(29_batched_gemm_bias_e_permute)
add_subdirectory(30_grouped_convnd_fwd_bias_relu_add)
add_subdirectory(31_batched_gemm_gemm)
add_subdirectory(32_batched_gemm_softmax_gemm)
add_subdirectory(32_batched_gemm_scale_softmax_gemm)
add_subdirectory(33_multiple_reduce)
add_subdirectory(34_batchnorm)
......@@ -561,11 +561,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatAB,
decltype(acc_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1),
decltype(acc_element_op),
tensor_operation::element_wise::PassThrough,
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
Sequence<1, 0, 2>,
2,
n4>{acc_element_op};
n4>{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy
auto b1_blockwise_copy =
......@@ -717,6 +717,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
blockwise_gemm,
acc_thread_buf,
num_k_block_main_loop);
// Acc0 elementwise Op
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
// softmax
SoftmaxBuf& max = blockwise_softmax.max_value_buf;
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment