Unverified Commit bac7df8f authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

use scale (#363)

parent c961ce92
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
...@@ -51,7 +51,7 @@ using CLayout = Row; ...@@ -51,7 +51,7 @@ using CLayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
using B0ElementOp = PassThrough; using B0ElementOp = PassThrough;
using Acc0ElementOp = PassThrough; using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough; using B1ElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
...@@ -122,7 +122,7 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm< ...@@ -122,7 +122,7 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
AccDataType, AccDataType,
AElementOp, AElementOp,
B0ElementOp, B0ElementOp,
CElementOp>; Acc0ElementOp>;
// Ref Softmax: fp32 in, fp16 out // Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance = using ReferenceSoftmaxInstance =
...@@ -157,6 +157,7 @@ int main(int argc, char* argv[]) ...@@ -157,6 +157,7 @@ int main(int argc, char* argv[])
ck::index_t BatchStrideB0 = -1; ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideB1 = -1; ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideC = -1; ck::index_t BatchStrideC = -1;
float alpha = 1;
if(argc == 1) if(argc == 1)
{ {
...@@ -181,7 +182,7 @@ int main(int argc, char* argv[]) ...@@ -181,7 +182,7 @@ int main(int argc, char* argv[])
BatchCount = std::stoi(argv[8]); BatchCount = std::stoi(argv[8]);
} }
else if(argc == 17) else if(argc == 18)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -203,6 +204,8 @@ int main(int argc, char* argv[]) ...@@ -203,6 +204,8 @@ int main(int argc, char* argv[])
BatchStrideB0 = std::stoi(argv[14]); BatchStrideB0 = std::stoi(argv[14]);
BatchStrideB1 = std::stoi(argv[15]); BatchStrideB1 = std::stoi(argv[15]);
BatchStrideC = std::stoi(argv[16]); BatchStrideC = std::stoi(argv[16]);
alpha = std::stof(argv[17]);
} }
else else
{ {
...@@ -211,6 +214,7 @@ int main(int argc, char* argv[]) ...@@ -211,6 +214,7 @@ int main(int argc, char* argv[])
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, " printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
"BatchStrideB0, BatchStrideB1, BatchStrideC\n"); "BatchStrideB0, BatchStrideB1, BatchStrideC\n");
printf("arg18: alpha\n");
exit(0); exit(0);
} }
...@@ -304,7 +308,7 @@ int main(int argc, char* argv[]) ...@@ -304,7 +308,7 @@ int main(int argc, char* argv[])
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{}; auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{}; auto acc0_element_op = Acc0ElementOp{alpha};
auto b1_element_op = B1ElementOp{}; auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
...@@ -368,7 +372,7 @@ int main(int argc, char* argv[]) ...@@ -368,7 +372,7 @@ int main(int argc, char* argv[])
auto ref_gemm0 = ReferenceGemm0Instance{}; auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument( auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, PassThrough{}); a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
......
# TODO: add example batched_gemm_gemm_xdl_fp16
add_example_executable(example_batched_gemm_softmax_gemm_xdl_fp16 batched_gemm_softmax_gemm_xdl_fp16.cpp)
...@@ -46,7 +46,7 @@ add_subdirectory(28_grouped_gemm_bias_e_permute) ...@@ -46,7 +46,7 @@ add_subdirectory(28_grouped_gemm_bias_e_permute)
add_subdirectory(29_batched_gemm_bias_e_permute) add_subdirectory(29_batched_gemm_bias_e_permute)
add_subdirectory(30_grouped_convnd_fwd_bias_relu_add) add_subdirectory(30_grouped_convnd_fwd_bias_relu_add)
add_subdirectory(31_batched_gemm_gemm) add_subdirectory(31_batched_gemm_gemm)
add_subdirectory(32_batched_gemm_softmax_gemm) add_subdirectory(32_batched_gemm_scale_softmax_gemm)
add_subdirectory(33_multiple_reduce) add_subdirectory(33_multiple_reduce)
add_subdirectory(34_batchnorm) add_subdirectory(34_batchnorm)
...@@ -561,11 +561,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -561,11 +561,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatAB, FloatAB,
decltype(acc_thread_desc_k0_m_k1), decltype(acc_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1), decltype(a1_thread_desc_k0_m_k1),
decltype(acc_element_op), tensor_operation::element_wise::PassThrough,
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>, Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
Sequence<1, 0, 2>, Sequence<1, 0, 2>,
2, 2,
n4>{acc_element_op}; n4>{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy // B1 matrix blockwise copy
auto b1_blockwise_copy = auto b1_blockwise_copy =
...@@ -717,6 +717,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -717,6 +717,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
blockwise_gemm, blockwise_gemm,
acc_thread_buf, acc_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
// Acc0 elementwise Op
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
// softmax // softmax
SoftmaxBuf& max = blockwise_softmax.max_value_buf; SoftmaxBuf& max = blockwise_softmax.max_value_buf;
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf; SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment