Commit f91dad8e authored by Anthony Chang
Browse files

scaling

parent a7e00533
...@@ -54,7 +54,7 @@ using CPermuteNumDims_G_M_O = ...@@ -54,7 +54,7 @@ using CPermuteNumDims_G_M_O =
using AElementOp = PassThrough; using AElementOp = PassThrough;
using B0ElementOp = PassThrough; using B0ElementOp = PassThrough;
using Acc0ElementOp = PassThrough; using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough; using B1ElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
...@@ -126,7 +126,7 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm< ...@@ -126,7 +126,7 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
AccDataType, AccDataType,
AElementOp, AElementOp,
B0ElementOp, B0ElementOp,
CElementOp>; Acc0ElementOp>;
// Ref Softmax: fp32 in, fp16 out // Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance = using ReferenceSoftmaxInstance =
...@@ -159,6 +159,7 @@ int main(int argc, char* argv[]) ...@@ -159,6 +159,7 @@ int main(int argc, char* argv[])
ck::index_t BatchStrideA = -1; ck::index_t BatchStrideA = -1;
ck::index_t BatchStrideB0 = -1; ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideB1 = -1; ck::index_t BatchStrideB1 = -1;
float alpha = 1;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
...@@ -178,7 +179,7 @@ int main(int argc, char* argv[]) ...@@ -178,7 +179,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 9) else if(argc == 11)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -190,14 +191,16 @@ int main(int argc, char* argv[]) ...@@ -190,14 +191,16 @@ int main(int argc, char* argv[])
O = std::stoi(argv[7]); O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]); G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]); G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, " printf("arg4 to 11: M, N, K, O, G0, G1\n");
"BatchStrideB0, BatchStrideB1, BatchStrideC\n"); printf("arg10: scale (alpha)\n");
exit(0); exit(0);
} }
...@@ -292,7 +295,7 @@ int main(int argc, char* argv[]) ...@@ -292,7 +295,7 @@ int main(int argc, char* argv[])
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{}; auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{}; auto acc0_element_op = Acc0ElementOp{alpha};
auto b1_element_op = B1ElementOp{}; auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
...@@ -359,7 +362,7 @@ int main(int argc, char* argv[]) ...@@ -359,7 +362,7 @@ int main(int argc, char* argv[])
auto ref_gemm0 = ReferenceGemm0Instance{}; auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument( auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, PassThrough{}); a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
......
...@@ -212,9 +212,9 @@ int main(int argc, char* argv[]) ...@@ -212,9 +212,9 @@ int main(int argc, char* argv[])
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, " printf("arg4 to 16: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
"BatchStrideB0, BatchStrideB1, BatchStrideC\n"); "BatchStrideB0, BatchStrideB1, BatchStrideC\n");
printf("arg18: alpha\n"); printf("arg17: scale (alpha)\n");
exit(0); exit(0);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment