"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "08c4d764a51e795515d49a2d8aaabdee1ba66ab7"
Unverified Commit bb4b82a9 authored by rocking5566's avatar rocking5566 Committed by GitHub
Browse files

Hotfix eltiwseop (#242)



* Use vector constructor instead

* Fix typo

* Move blockSize to the MakeArgumentPointer

* Fix naming

* Fix clang format

* remove blockSize from DeviceBinaryElementwise::Argument()
Co-authored-by: default avatarrocking <chunylai@amd.com>
Co-authored-by: default avatarChao Liu <chao.liu2@amd.com>
parent 0ffe956a
......@@ -74,9 +74,7 @@ int main()
};
Tensor<ABDataType> a_m_n(f_host_tensor_descriptor2d(M, N, Stride));
Tensor<ABDataType> b_n(f_host_tensor_descriptor1d(N, 1));
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, Stride));
a_m_n.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
......
......@@ -56,7 +56,7 @@ int main()
Tensor<ABDataType> a_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ABDataType> b_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ABDataType> c_m(f_host_tensor_descriptor1d(M, 1));
Tensor<CDataType> c_m(f_host_tensor_descriptor1d(M, 1));
a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
b_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
......
......@@ -5,7 +5,6 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_utility.hpp"
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
......@@ -56,29 +55,29 @@ int main()
std::vector<std::size_t> nchw = {4, 16, 32, 32};
Tensor<ABDataType> a_m(nchw);
Tensor<ABDataType> b_m(nchw);
Tensor<ABDataType> c_m(nchw);
Tensor<ABDataType> a(nchw);
Tensor<ABDataType> b(nchw);
Tensor<CDataType> c(nchw);
a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
b_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
a.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
b.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace());
DeviceMem b_m_device_buf(sizeof(ABDataType) * b_m.mDesc.GetElementSpace());
DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ABDataType) * a.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(ABDataType) * b.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c.mDesc.GetElementSpace());
a_m_device_buf.ToDevice(a_m.mData.data());
b_m_device_buf.ToDevice(b_m.mData.data());
a_device_buf.ToDevice(a.mData.data());
b_device_buf.ToDevice(b.mData.data());
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(
a_m_device_buf.GetDeviceBuffer(),
b_m_device_buf.GetDeviceBuffer(),
c_m_device_buf.GetDeviceBuffer(),
ck::convert_vector_element_type<std::size_t, ck::index_t>(nchw),
ck::convert_vector_element_type<std::size_t, ck::index_t>(a_m.mDesc.GetStrides()),
ck::convert_vector_element_type<std::size_t, ck::index_t>(b_m.mDesc.GetStrides()),
ck::convert_vector_element_type<std::size_t, ck::index_t>(c_m.mDesc.GetStrides()),
a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
std::vector<ck::index_t>{nchw.begin(), nchw.end()},
std::vector<ck::index_t>{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()},
std::vector<ck::index_t>{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()},
std::vector<ck::index_t>{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()},
Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
......@@ -96,17 +95,17 @@ int main()
bool pass = true;
if(do_verification)
{
c_m_device_buf.FromDevice(c_m.mData.data());
Tensor<CDataType> host_c_m(nchw);
c_device_buf.FromDevice(c.mData.data());
Tensor<CDataType> host_c(nchw);
host_elementwise4D<Tensor<ABDataType>,
Tensor<ABDataType>,
Tensor<CDataType>,
EltwiseComputeDataType,
Add>(host_c_m, a_m, b_m, nchw, Add{});
Add>(host_c, a, b, nchw, Add{});
pass &= ck::utils::check_err(
c_m.mData, host_c_m.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
pass &=
ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
}
return pass ? 0 : 1;
......
......@@ -19,8 +19,6 @@ template <typename ADataType,
index_t ScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator
{
DeviceBinaryElementwise(index_t blockSize = 256) : BaseOperator(), blockSize_(blockSize) {}
static constexpr auto I0 = Number<0>{};
template <typename Desc_M0>
......@@ -81,18 +79,18 @@ struct DeviceBinaryElementwise : public BaseOperator
const std::vector<index_t>& stride_a,
const std::vector<index_t>& stride_b,
const std::vector<index_t>& stride_c,
ElementwiseFunctor functor,
index_t blockSize)
ElementwiseFunctor functor)
: p_a_(p_a),
p_b_(p_b),
p_c_(p_c),
shape_(shape),
functor_(functor),
blockSize_(256),
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
{
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize);
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize);
c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, blockSize);
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_);
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_);
c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, blockSize_);
}
const ADataType* p_a_;
......@@ -103,13 +101,12 @@ struct DeviceBinaryElementwise : public BaseOperator
GridDesc_M0 b_grid_desc_m0_;
GridDesc_M0 c_grid_desc_m0_;
ElementwiseFunctor functor_;
index_t blockSize_;
index_t gridSize_;
};
struct Invoker : public BaseInvoker
{
Invoker(index_t blockSize) : BaseInvoker(), blockSize_(blockSize) {}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel = kernel_binary_elementwise_1d<GridwiseBinEltwise,
......@@ -122,7 +119,7 @@ struct DeviceBinaryElementwise : public BaseOperator
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize_),
dim3(blockSize_),
dim3(arg.blockSize_),
0,
arg.p_a_,
arg.p_b_,
......@@ -140,8 +137,6 @@ struct DeviceBinaryElementwise : public BaseOperator
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
index_t blockSize_;
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
......@@ -173,14 +168,10 @@ struct DeviceBinaryElementwise : public BaseOperator
stride_a,
stride_b,
stride_c,
functor,
blockSize_);
functor);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{blockSize_});
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
std::string GetTypeString() const override
{
......@@ -195,8 +186,6 @@ struct DeviceBinaryElementwise : public BaseOperator
return str.str();
}
index_t blockSize_;
};
} // namespace device
......
#pragma once
#include <vector>
namespace ck {
template <typename Src, typename Dst>
inline std::vector<Dst> convert_vector_element_type(const std::vector<Src>& inData)
{
std::vector<Dst> outData;
for(auto elem : inData)
outData.push_back(static_cast<Dst>(elem));
return (outData);
};
}; // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment