Commit 74320196 authored by wangshaojie6's avatar wangshaojie6
Browse files

remove lower triangle gemm reference struct

parent 7ae26b79
...@@ -121,7 +121,7 @@ using DeviceGemmInstance = ...@@ -121,7 +121,7 @@ using DeviceGemmInstance =
true>; // MaskOutUpperTriangle true>; // MaskOutUpperTriangle
// Ref Gemm0: fp16 in, fp32 out // Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemmUpperTriangleMinusInf<ADataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType, B0DataType,
AccDataType, AccDataType,
AccDataType, AccDataType,
...@@ -368,6 +368,11 @@ int main(int argc, char* argv[]) ...@@ -368,6 +368,11 @@ int main(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if (idx[1] < idx[2]) self(idx) = -ck::NumericLimits<float>::Infinity();
});
auto ref_softmax = ReferenceSoftmaxInstance{}; auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker(); auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
......
...@@ -133,133 +133,6 @@ struct ReferenceBatchedGemm : public device::BaseOperator ...@@ -133,133 +133,6 @@ struct ReferenceBatchedGemm : public device::BaseOperator
} }
}; };
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct ReferenceBatchedGemmUpperTriangleMinusInf : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_g_m_k,
const Tensor<BDataType>& b_g_k_n,
Tensor<CDataType>& c_g_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_g_m_k_{a_g_m_k},
b_g_k_n_{b_g_k_n},
c_g_m_n_{c_g_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_g_m_k_;
const Tensor<BDataType>& b_g_k_n_;
Tensor<CDataType>& c_g_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceBatchedGemmUpperTriangleMinusInf::Argument;
float Run(const Argument& arg)
{
auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) {
const int K = arg.a_g_m_k_.mDesc.GetLengths()[2];
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a;
BDataType v_b;
arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
AccDataType v_c;
if(((n >> 0) << 0) <= ((m >> 0) << 0))
{
arg.c_element_op_(v_c, v_acc);
}
else
{
v_c = -ck::NumericLimits<float>::Infinity();
}
arg.c_g_m_n_(g, m, n) = ck::type_convert<CDataType>(v_c);
};
make_ParallelTensorFunctor(f_gmk_gkn_gmn,
arg.c_g_m_n_.mDesc.GetLengths()[0],
arg.c_g_m_n_.mDesc.GetLengths()[1],
arg.c_g_m_n_.mDesc.GetLengths()[2])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_g_m_k,
const Tensor<BDataType>& b_g_k_n,
Tensor<CDataType>& c_g_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceBatchedGemmUpperTriangleMinusInf"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host } // namespace host
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment