Commit 6a781e51 authored by rocking
Browse files

Add broadcast div, the final step of softmax

parent b05a594e
......@@ -155,9 +155,21 @@ struct Sub_Exp
}
};
using DeviceElementwiseInstance = ck::tensor_operation::device::
// Binary element-wise functor: stores the quotient of its two inputs.
// Used as the operation of the broadcast-divide step, which this commit
// describes as the final step of softmax.
struct Div
{
    // Writes src1 / src2 into dst.
    //   dst  - output element, overwritten with the quotient
    //   src1 - numerator
    //   src2 - denominator (no zero check is performed here)
    __host__ __device__ constexpr void operator()(CDataType& dst,
                                                  const CDataType& src1,
                                                  const CDataType& src2) const
    {
        dst = src1 / src2;
    }
};
// Device element-wise 2D op instantiated with the Sub_Exp functor and with
// CDataType for all three type parameters.
// NOTE(review): the trailing integer template arguments (16, 16, 8, 8, 1, ...)
// are presumably tiling/vectorization tuning parameters of
// DeviceElementwise_2D — confirm against its declaration.
using DeviceElementwiseSubExpInstance = ck::tensor_operation::device::
DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub_Exp, 16, 16, 8, 8, 1, 1, 1, 1, 1>;
// Same template arguments as the Sub_Exp instance above, except that the
// functor is Div (element-wise division).
using DeviceElementwiseDivInstance = ck::tensor_operation::device::
DeviceElementwise_2D<CDataType, CDataType, CDataType, Div, 16, 16, 8, 8, 1, 1, 1, 1, 1>;
// Reference GEMM type from the ck::tensor_operation::host namespace, with
// PassThrough supplied for its three trailing template parameters.
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
......@@ -230,10 +242,11 @@ int main(int argc, char* argv[])
Tensor<CDataType> exp_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> exp_n_sum(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
std::vector<std::size_t>({1}));
Tensor<CDataType> softmax_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
const auto c_m_n_shape = ck::to_int_vector(c_m_n.mDesc.GetLengths());
const auto c_m_n_stride = ck::to_int_vector(c_m_n.mDesc.GetStrides());
const auto reduce_n_shape = ck::to_int_vector(c_n_max.mDesc.GetLengths());
const auto c_m_n_shape = ck::to_int_vector(c_m_n.mDesc.GetLengths());
const auto c_m_n_stride = ck::to_int_vector(c_m_n.mDesc.GetStrides());
const auto reduce_n_shape = ck::to_int_vector(c_n_max.mDesc.GetLengths());
const auto reduce_n_stride = ck::to_int_vector(c_n_max.mDesc.GetStrides());
size_t reduce_total_length = c_m_n.mDesc.GetElementSize() / c_n_max.mDesc.GetElementSize();
......@@ -244,6 +257,7 @@ int main(int argc, char* argv[])
std::cout << "c_n_max: " << c_n_max.mDesc << std::endl;
std::cout << "exp_m_n: " << exp_m_n.mDesc << std::endl;
std::cout << "exp_n_sum: " << exp_n_sum.mDesc << std::endl;
std::cout << "softmax_m_n: " << softmax_m_n.mDesc << std::endl;
switch(init_method)
{
......@@ -264,6 +278,7 @@ int main(int argc, char* argv[])
DeviceMem indices_device_buf(0);
DeviceMem exp_m_n_device_buf(sizeof(CDataType) * exp_m_n.mDesc.GetElementSpace());
DeviceMem exp_n_sum_device_buf(sizeof(CDataType) * exp_n_sum.mDesc.GetElementSpace());
DeviceMem softmax_m_n_device_buf(sizeof(CDataType) * softmax_m_n.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
......@@ -295,7 +310,7 @@ int main(int argc, char* argv[])
gemm_invoker.Run(gemm_argument, nrepeat);
// do reduce max
auto reduce_max = DeviceReduceMaxInstance{};
auto reduce_max = DeviceReduceMaxInstance{};
auto reduce_max_workaspace_size = reduce_max.GetWorkspaceSizeInBytes(c_m_n_shape, reduceDims);
DeviceMem reduce_max_workaspace_device_buf(reduce_max_workaspace_size);
......@@ -324,7 +339,7 @@ int main(int argc, char* argv[])
reduce_max_invoker_ptr->Run(reduce_max_argument_ptr.get(), nrepeat);
// do broadcast sub and exp
auto broadcastSubExp = DeviceElementwiseInstance{};
auto broadcastSubExp = DeviceElementwiseSubExpInstance{};
auto broadcastSubExp_argument_ptr =
broadcastSubExp.MakeArgumentPointer(c_m_n_device_buf.GetDeviceBuffer(),
c_n_max_device_buf.GetDeviceBuffer(),
......@@ -345,7 +360,7 @@ int main(int argc, char* argv[])
broadcastSubExp_invoker_ptr->Run(broadcastSubExp_argument_ptr.get(), nrepeat);
// do reduce sum - denominator of softmax
auto reduce_sum = DeviceReduceSumInstance{};
auto reduce_sum = DeviceReduceSumInstance{};
auto reduce_sum_workaspace_size = reduce_sum.GetWorkspaceSizeInBytes(c_m_n_shape, reduceDims);
DeviceMem reduce_sum_workaspace_device_buf(reduce_sum_workaspace_size);
......@@ -373,7 +388,27 @@ int main(int argc, char* argv[])
auto reduce_sum_invoker_ptr = reduce_sum.MakeInvokerPointer();
reduce_sum_invoker_ptr->Run(reduce_sum_argument_ptr.get(), nrepeat);
// TODO - Need BroadcastDiv
// do broadcast div
auto broadcastDiv = DeviceElementwiseDivInstance{};
auto broadcastDiv_argument_ptr =
broadcastDiv.MakeArgumentPointer(exp_m_n_device_buf.GetDeviceBuffer(),
exp_n_sum_device_buf.GetDeviceBuffer(),
softmax_m_n_device_buf.GetDeviceBuffer(),
{M, N},
{StrideC, 1},
{0, 1},
{StrideC, 1},
Div{});
if(!broadcastDiv.IsSupportedArgument(broadcastDiv_argument_ptr.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceElementwise_2D instance, exiting!");
};
auto broadcastDiv_invoker_ptr = broadcastDiv.MakeInvokerPointer();
broadcastDiv_invoker_ptr->Run(broadcastDiv_argument_ptr.get(), nrepeat);
// TODO = do_verification
(void)do_verification;
return 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment