Commit bf210540 authored by Aleksander Dudek's avatar Aleksander Dudek
Browse files

Merge branch 'develop' into reference_gemm_alloc

parents 36d1b311 0e54d7ae
We'd love for you to contribute to our source code!
Some helpful links:
- [Code of Conduct guidelines](https://www.contributor-covenant.org/version/2/1/code_of_conduct/code_of_conduct.txt)
- [New issue guidelines](https://github.com/rocm/composable_kernel/blob/develop/.github/ISSUE_TEMPLATE.md)
- [Submitting a pull request guidelines](https://github.com/rocm/composable_kernel/blob/develop/.github/PULL_REQUEST_TEMPLATE.md)
- [Maintainers](https://github.com/rocm/composable_kernel/blob/develop/CONTRIBUTORS.md)
- [General information](https://github.com/rocm/composable_kernel/blob/develop/README.md)
- [ROCm documentation](https://rocm.docs.amd.com/en/latest/how-to/llm-fine-tuning-optimization/optimizing-with-composable-kernel.html)
\ No newline at end of file
When creating an issue, please check if a similar issue already exists.
### When reporting a bug, please include:
- [ ] A descriptive title
- [ ] An isolated way to reproduce the behavior (preferably a docker container with a repro)
- [ ] ROCm version, clang version, Composable Kernel commit pin
- [ ] Environment variables
- [ ] The behavior you expect to see, and the behavior you actually see
### When requesting a feature, please include:
- [ ] A descriptive title
- [ ] A detailed description of the problem you are trying to solve
- [ ] An overview of the suggested solution
- [ ] Explanation why the solution is an improvement
\ No newline at end of file
## Proposed changes
Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.
## Checklist
Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally
- [ ] I have added inline documentation which enables the maintainers with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant after this pull request
- [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
- [ ] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged
## Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
...@@ -64,6 +64,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -64,6 +64,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
nano \ nano \
zlib1g-dev \ zlib1g-dev \
zip \ zip \
libzstd-dev \
openssh-server \ openssh-server \
clang-format-12 \ clang-format-12 \
kmod && \ kmod && \
...@@ -93,7 +94,7 @@ RUN pip install --upgrade cmake==3.27.5 && \ ...@@ -93,7 +94,7 @@ RUN pip install --upgrade cmake==3.27.5 && \
dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \ dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \
# Install packages for processing the performance results # Install packages for processing the performance results
pip3 install --upgrade pip && \ pip3 install --upgrade pip && \
pip3 install sqlalchemy==1.4.46 pymysql pandas==2.0.3 setuptools-rust sshtunnel==0.4.0 && \ pip3 install sqlalchemy==2.0.36 pymysql pandas==2.2.3 setuptools-rust sshtunnel==0.4.0 && \
# Add render group # Add render group
groupadd -f render && \ groupadd -f render && \
# Install the new rocm-cmake version # Install the new rocm-cmake version
......
...@@ -566,11 +566,9 @@ def Build_CK(Map conf=[:]){ ...@@ -566,11 +566,9 @@ def Build_CK(Map conf=[:]){
ls -ltr ls -ltr
CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install" CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install"
cmake --build build -- -j cmake --build build -- -j
ctest --test-dir build
""" """
} }
dir("hipTensor-${params.hipTensor_branch}/build"){
sh 'ctest'
}
} }
} }
} }
......
rocm-docs-core==1.11.0 rocm-docs-core==1.12.0
sphinxcontrib-bibtex==2.6.3 sphinxcontrib-bibtex==2.6.3
...@@ -103,7 +103,7 @@ requests==2.32.3 ...@@ -103,7 +103,7 @@ requests==2.32.3
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core==1.11.0 rocm-docs-core==1.12.0
# via -r requirements.in # via -r requirements.in
six==1.16.0 six==1.16.0
# via pybtex # via pybtex
......
...@@ -78,14 +78,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD ...@@ -78,14 +78,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
2, // ABlockTransferSrcVectorDim 2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1 8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM 0, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim 2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1 8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN 0, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
......
...@@ -35,7 +35,8 @@ auto create_args(int argc, char* argv[]) ...@@ -35,7 +35,8 @@ auto create_args(int argc, char* argv[])
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension") arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension") .insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n") .insert("x_stride", "-1", "input stride per row, if -1 then equal to n")
.insert("y_stride", "-1", "output stride per row, if -1 then equal to n")
.insert("e", "1e-5", "epsilon") .insert("e", "1e-5", "epsilon")
.insert("v", "1", "cpu validation or not") .insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision") .insert("prec", "fp16", "precision")
...@@ -49,11 +50,14 @@ auto create_args(int argc, char* argv[]) ...@@ -49,11 +50,14 @@ auto create_args(int argc, char* argv[])
template <typename DataType> template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
ck_tile::index_t m = arg_parser.get_int("m"); ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n"); ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride"); ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(stride < 0) if(x_stride < 0)
stride = n; x_stride = n;
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
if(y_stride < 0)
y_stride = n;
std::string data_type = arg_parser.get_str("prec"); std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v"); int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup"); int warmup = arg_parser.get_int("warmup");
...@@ -68,14 +72,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -68,14 +72,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
using ComputeDataType = float; using ComputeDataType = float;
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1}); ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XScaleDataType> xscale_host({n}); ck_tile::HostTensor<XScaleDataType> xscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1}); ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1}); ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host); ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
...@@ -116,7 +120,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -116,7 +120,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(),
m, m,
n, n,
stride}; x_stride,
y_stride};
auto kargs = Kernel::MakeKargs(args); auto kargs = Kernel::MakeKargs(args);
...@@ -133,7 +138,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -133,7 +138,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation) if(do_validation)
{ {
using YDataType = ComputeDataType; using YDataType = ComputeDataType;
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {stride, 1}); ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
// smooth outlier // smooth outlier
{ {
auto f = [&](auto n_) { auto f = [&](auto n_) {
...@@ -183,7 +188,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -183,7 +188,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf.FromDevice(qy_host_dev.data()); qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>(); auto [rtol, atol] = get_elimit<QYDataType>();
if(stride == n) if(y_stride == n)
{ {
pass = ck_tile::check_err(qy_host_dev, pass = ck_tile::check_err(qy_host_dev,
qy_host_ref, qy_host_ref,
...@@ -195,10 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -195,10 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
for(int i_r = 0; i_r < m; i_r++) for(int i_r = 0; i_r < m; i_r++)
{ {
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride, std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
qy_host_dev.begin() + i_r * stride + n); qy_host_dev.begin() + i_r * y_stride +
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride, n);
qy_host_ref.begin() + i_r * stride + n); std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
qy_host_ref.begin() + i_r * y_stride +
n);
pass &= ck_tile::check_err(qy_host_dev_row, pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row, qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) + std::string("qy[") + std::to_string(i_r) +
...@@ -210,8 +217,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -210,8 +217,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
std::cout << "[" << data_type << "]" std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; << ", y_stride:" << y_stride << ", valid:" << (pass ? "y" : "n") << std::flush
<< std::endl;
} }
return pass; return pass;
......
...@@ -33,7 +33,8 @@ auto create_args(int argc, char* argv[]) ...@@ -33,7 +33,8 @@ auto create_args(int argc, char* argv[])
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension") arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension") .insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n") .insert("x_stride", "-1", "input stride per row, if -1 then equal to n")
.insert("y_stride", "-1", "output stride per row, if -1 then equal to n")
.insert("v", "1", "cpu validation or not") .insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not") .insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision") .insert("prec", "fp16", "precision")
...@@ -47,18 +48,21 @@ auto create_args(int argc, char* argv[]) ...@@ -47,18 +48,21 @@ auto create_args(int argc, char* argv[])
template <typename DataType> template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
ck_tile::index_t m = arg_parser.get_int("m"); ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n"); ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride"); ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(stride < 0) if(x_stride < 0)
stride = n; x_stride = n;
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
if(y_stride < 0)
y_stride = n;
std::string data_type = arg_parser.get_str("prec"); std::string data_type = arg_parser.get_str("prec");
int kname = arg_parser.get_int("kname"); int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v"); int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup"); int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat"); int repeat = arg_parser.get_int("repeat");
assert(stride >= n); assert(x_stride >= n);
using TypeConfig = SmoothquantTypeConfig<DataType>; using TypeConfig = SmoothquantTypeConfig<DataType>;
...@@ -69,14 +73,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -69,14 +73,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
using ComputeDataType = typename TypeConfig::ComputeDataType; using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1}); ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XScaleDataType> xscale_host({n}); ck_tile::HostTensor<XScaleDataType> xscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1}); ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1}); ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host); ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
...@@ -90,7 +94,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -90,7 +94,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
xscale_buf.ToDevice(xscale_host.data()); xscale_buf.ToDevice(xscale_host.data());
std::cout << "[" << data_type << "]" std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride
<< std::flush;
smoothquant_traits traits{data_type}; smoothquant_traits traits{data_type};
...@@ -100,7 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -100,7 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(),
m, m,
n, n,
stride}; x_stride,
y_stride};
float ave_time = smoothquant( float ave_time = smoothquant(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
...@@ -116,7 +122,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -116,7 +122,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation) if(do_validation)
{ {
using YDataType = ComputeDataType; using YDataType = ComputeDataType;
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {stride, 1}); ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
// smooth outlier // smooth outlier
{ {
auto f = [&](auto n_) { auto f = [&](auto n_) {
...@@ -166,7 +172,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -166,7 +172,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf.FromDevice(qy_host_dev.data()); qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>(); auto [rtol, atol] = get_elimit<QYDataType>();
if(stride == n) if(y_stride == n)
{ {
pass = ck_tile::check_err(qy_host_dev, pass = ck_tile::check_err(qy_host_dev,
qy_host_ref, qy_host_ref,
...@@ -178,10 +184,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -178,10 +184,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
for(int i_r = 0; i_r < m; i_r++) for(int i_r = 0; i_r < m; i_r++)
{ {
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride, std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
qy_host_dev.begin() + i_r * stride + n); qy_host_dev.begin() + i_r * y_stride +
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride, n);
qy_host_ref.begin() + i_r * stride + n); std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
qy_host_ref.begin() + i_r * y_stride +
n);
pass &= ck_tile::check_err(qy_host_dev_row, pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row, qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) + std::string("qy[") + std::to_string(i_r) +
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
#include <string> #include <string>
#include <sstream> #include <sstream>
#include <regex>
#include <optional>
#include "ck/stream_config.hpp" #include "ck/stream_config.hpp"
...@@ -12,6 +14,34 @@ namespace ck { ...@@ -12,6 +14,34 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
#define GET_OBJECT_NAME_IMLP \
std::optional<std::string> GetObjectName() const override \
{ \
std::string str = __PRETTY_FUNCTION__; \
static std::regex obj_name_expr{"<std::string> (.*)::GetObjectName"}; \
std::smatch match; \
if(!std::regex_search(str, match, obj_name_expr)) \
{ \
return str; \
} \
return std::string(match[1]) + ';'; \
}
#define GET_TEMPLATE_INFO_IMPL \
std::optional<std::string> GetTemplateInfo() const override \
{ \
std::string str = __PRETTY_FUNCTION__; \
static std::regex template_expr{"\\[(.*)\\]"}; \
std::smatch match; \
if(!std::regex_search(str, match, template_expr)) \
{ \
return std::nullopt; \
} \
return std::string(match[1]); \
}
#define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL
struct BaseArgument struct BaseArgument
{ {
BaseArgument() = default; BaseArgument() = default;
...@@ -48,6 +78,10 @@ struct BaseOperator ...@@ -48,6 +78,10 @@ struct BaseOperator
virtual std::string GetTypeIdName() const { return typeid(*this).name(); } virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
virtual std::optional<std::string> GetObjectName() const { return std::nullopt; }
virtual std::optional<std::string> GetTemplateInfo() const { return std::nullopt; }
virtual std::string GetTypeIdHashCode() const virtual std::string GetTypeIdHashCode() const
{ {
std::ostringstream oss; std::ostringstream oss;
......
...@@ -89,7 +89,8 @@ struct DeviceBatchedGemmV2MultiD : public BaseOperator ...@@ -89,7 +89,8 @@ struct DeviceBatchedGemmV2MultiD : public BaseOperator
index_t BatchStrideE, index_t BatchStrideE,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0; CDEElementwiseOperation cde_element_op,
index_t KBatch) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
...@@ -41,12 +41,15 @@ __global__ void ...@@ -41,12 +41,15 @@ __global__ void
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t g_idx = blockIdx.z % karg.Batch; const index_t g_idx = blockIdx.z % karg.Batch;
const index_t k_idx = blockIdx.z / karg.Batch;
const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
// populate pointer, desc for Ds // populate pointer, desc for Ds
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
// D pointer // D pointer
...@@ -54,8 +57,8 @@ __global__ void ...@@ -54,8 +57,8 @@ __global__ void
}); });
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_batch_offset, karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + b_batch_offset, karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid, karg.p_ds_grid,
karg.p_c_grid + c_batch_offset, karg.p_c_grid + c_batch_offset,
p_shared, p_shared,
...@@ -87,12 +90,15 @@ __global__ void ...@@ -87,12 +90,15 @@ __global__ void
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t g_idx = blockIdx.z % karg.Batch; const index_t g_idx = blockIdx.z % karg.Batch;
const index_t k_idx = blockIdx.z / karg.Batch;
const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
// populate pointer, desc for Ds // populate pointer, desc for Ds
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
// D pointer // D pointer
...@@ -100,8 +106,8 @@ __global__ void ...@@ -100,8 +106,8 @@ __global__ void
}); });
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_batch_offset, karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + b_batch_offset, karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid, karg.p_ds_grid,
karg.p_c_grid + c_batch_offset, karg.p_c_grid + c_batch_offset,
p_shared_0, p_shared_0,
...@@ -303,7 +309,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -303,7 +309,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t Batch_, index_t Batch_,
AElementwiseOperation a_element_op_, AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_, BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_) CElementwiseOperation c_element_op_,
index_t KBatch_)
: GridwiseGemm::Argument{p_a_grid_, : GridwiseGemm::Argument{p_a_grid_,
p_b_grid_, p_b_grid_,
p_ds_grid_, p_ds_grid_,
...@@ -315,7 +322,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -315,7 +322,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
StrideB_, StrideB_,
StrideDs_, StrideDs_,
StrideE_, StrideE_,
1, KBatch_,
a_element_op_, a_element_op_,
b_element_op_, b_element_op_,
c_element_op_}, c_element_op_},
...@@ -336,13 +343,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -336,13 +343,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
arg.Print(); arg.Print();
} }
if(!GridwiseGemm::CheckValidity(arg) || arg.KBatch > 1) if(!GridwiseGemm::CheckValidity(arg))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
} }
index_t gdx, gdy, gdz; index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch); std::tie(gdx, gdy, gdz) =
GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch * arg.KBatch);
float ave_time = 0; float ave_time = 0;
...@@ -387,10 +395,11 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -387,10 +395,11 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
rotating_mem.Next(); rotating_mem.Next();
// clear c mem // clear c mem
if(arg_.KBatch > 1) if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, hipGetErrorString(
0, hipMemsetAsync(arg_.p_c_grid,
arg_.M * arg_.N * sizeof(CDataType), 0,
stream_config.stream_id_)); arg.Batch * arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
}; };
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>( ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
...@@ -889,7 +898,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -889,7 +898,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t BatchStrideE, index_t BatchStrideE,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op,
index_t KBatch = 1)
{ {
return Argument{static_cast<const ADataType*>(p_a), return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -909,7 +919,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -909,7 +919,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
Batch, Batch,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op}; c_element_op,
KBatch};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -934,7 +945,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -934,7 +945,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t BatchStrideE, index_t BatchStrideE,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override CElementwiseOperation c_element_op,
index_t KBatch = 1) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -954,7 +966,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -954,7 +966,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
Batch, Batch,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op,
KBatch);
} }
// polymorphic // polymorphic
......
...@@ -729,6 +729,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -729,6 +729,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return str.str(); return str.str();
} }
REGISTER_EXTRA_PRINTING_METHODS
}; };
} // namespace device } // namespace device
......
...@@ -41,7 +41,7 @@ __global__ void ...@@ -41,7 +41,7 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset, karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
...@@ -76,7 +76,7 @@ __global__ void ...@@ -76,7 +76,7 @@ __global__ void
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset, karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
...@@ -639,27 +639,27 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 ...@@ -639,27 +639,27 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
struct SplitKBatchOffset struct SplitKBatchOffset
{ {
__device__ SplitKBatchOffset(Argument& karg) __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{ {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>) if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{ {
a_k_split_offset = blockIdx.z * karg.KRead; a_k_split_offset = k_id * karg.KRead;
} }
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{ {
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; a_k_split_offset = k_id * karg.KRead * karg.StrideA;
} }
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>) if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{ {
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; b_k_split_offset = k_id * karg.KRead * karg.StrideB;
} }
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{ {
b_k_split_offset = blockIdx.z * karg.KRead; b_k_split_offset = k_id * karg.KRead;
} }
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1)) if(k_id < karg.KBatch - 1)
{ {
karg.K = karg.KRead; karg.K = karg.KRead;
} }
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp" #include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp" #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp" #include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment