"cacheflow/git@developer.sourcefind.cn:SIYIXNI/vllm.git" did not exist on "bab8f3dd0ddf18ed0e28f77de57f0e55c7097aff"
Commit d179a12a authored by Astha Rai
Browse files

Removed most of the extraneous code; testing with different dimensions.

parent dab61372
...@@ -48,7 +48,7 @@ int main() ...@@ -48,7 +48,7 @@ int main()
bool do_verification = true; bool do_verification = true;
bool time_kernel = true; bool time_kernel = true;
const int N = 16; const int N = 120;
const int H = 32; const int H = 32;
const int W = 64; const int W = 64;
...@@ -72,7 +72,6 @@ int main() ...@@ -72,7 +72,6 @@ int main()
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()}; std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths{N, H, W, C}; std::array<ck::index_t, 4> ab_lengths{N, H, W, C};
// std::copy(nhwc.begin(), nhwc.end(), ab_lengths.begin());
std::array<ck::index_t, 4> a_strides = {C * H * W, W, 1, H * W}; std::array<ck::index_t, 4> a_strides = {C * H * W, W, 1, H * W};
std::array<ck::index_t, 4> b_strides = {H * W * C, W * C, C, 1}; std::array<ck::index_t, 4> b_strides = {H * W * C, W * C, C, 1};
......
...@@ -69,7 +69,11 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -69,7 +69,11 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
template <typename Desc_MN> template <typename Desc_MN>
static auto PadDescriptor_MN_2d(Desc_MN desc_mn, index_t gridSize, index_t blockSize, index_t num_threads_m, index_t num_threads_n) static auto PadDescriptor_MN_2d(Desc_MN desc_mn,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n)
{ {
std::ignore = blockSize; std::ignore = blockSize;
std::ignore = gridSize; std::ignore = gridSize;
...@@ -172,8 +176,8 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -172,8 +176,8 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
elementwise_op_(elementwise_op), elementwise_op_(elementwise_op),
blockSize_(256), blockSize_(256),
gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future
num_threads_m_((gridSize_*blockSize_)/8), num_threads_m_((gridSize_ * blockSize_) / 16),
num_threads_n_(8) num_threads_n_(16)
{ {
static_assert(NumDim_m > 0, ""); static_assert(NumDim_m > 0, "");
static_assert(NumDim_n > 0, ""); static_assert(NumDim_n > 0, "");
...@@ -194,15 +198,23 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -194,15 +198,23 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
in_grid_2d_desc_tuple_ = generate_tuple( in_grid_2d_desc_tuple_ = generate_tuple(
[&](auto I) { [&](auto I) {
return MakeDescriptor_MN( return MakeDescriptor_MN(lengths,
lengths, inStridesArray[I.value], gridSize_, blockSize_, num_threads_m_, num_threads_n_); inStridesArray[I.value],
gridSize_,
blockSize_,
num_threads_m_,
num_threads_n_);
}, },
Number<NumInput>{}); Number<NumInput>{});
out_grid_2d_desc_tuple_ = generate_tuple( out_grid_2d_desc_tuple_ = generate_tuple(
[&](auto I) { [&](auto I) {
return MakeDescriptor_MN( return MakeDescriptor_MN(lengths,
lengths, outStridesArray[I.value], gridSize_, blockSize_, num_threads_m_, num_threads_n_); outStridesArray[I.value],
gridSize_,
blockSize_,
num_threads_m_,
num_threads_n_);
}, },
Number<NumOutput>{}); Number<NumOutput>{});
} }
...@@ -284,10 +296,8 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -284,10 +296,8 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
return false; return false;
}; };
bool valid = true; bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) { static_for<0, NumInput, 1>{}([&](auto I) {
std::cout << "running: " << I << std::endl;
if(!IsScalarPerVectorValid(pArg->lengths_, if(!IsScalarPerVectorValid(pArg->lengths_,
pArg->inStridesArray_[I.value], pArg->inStridesArray_[I.value],
InScalarPerVectorSeq::At(I), InScalarPerVectorSeq::At(I),
...@@ -296,7 +306,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -296,7 +306,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
}); });
static_for<0, NumOutput, 1>{}([&](auto I) { static_for<0, NumOutput, 1>{}([&](auto I) {
std::cout << "running 2: " << I << std::endl;
if(!IsScalarPerVectorValid(pArg->lengths_, if(!IsScalarPerVectorValid(pArg->lengths_,
pArg->outStridesArray_[I.value], pArg->outStridesArray_[I.value],
OutScalarPerVectorSeq::At(I), OutScalarPerVectorSeq::At(I),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment