"tests/pipelines/controlnet_flux/__init__.py" did not exist on "c7ba6ba2678ca7e4e58320da8209be8883a56322"
Commit 4c105089 authored by Jun Liu

Merge branch 'amd-develop' into amd-master

parents 4e5d50c2 a5fd9747
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
version: 2 version: 2
updates: updates:
- package-ecosystem: "pip" # See documentation for possible values - package-ecosystem: "pip" # See documentation for possible values
directory: "/docs/.sphinx" # Location of package manifests directory: "/docs/sphinx" # Location of package manifests
open-pull-requests-limit: 10 open-pull-requests-limit: 10
schedule: schedule:
interval: "daily" interval: "daily"
...@@ -49,10 +49,10 @@ build* ...@@ -49,10 +49,10 @@ build*
install.dir* install.dir*
# documentation artifacts # documentation artifacts
build/
_build/ _build/
_images/ _images/
_static/ _static/
_templates/ _templates/
_toc.yml _toc.yml
docBin/ docBin/
_doxygen/
...@@ -11,8 +11,8 @@ build: ...@@ -11,8 +11,8 @@ build:
sphinx: sphinx:
configuration: docs/conf.py configuration: docs/conf.py
formats: [htmlzip] formats: [htmlzip, pdf, epub]
python: python:
install: install:
- requirements: docs/.sphinx/requirements.txt - requirements: docs/sphinx/requirements.txt
...@@ -19,7 +19,7 @@ def runShell(String command){ ...@@ -19,7 +19,7 @@ def runShell(String command){
def getDockerImageName(){ def getDockerImageName(){
def img def img
if (params.ROCMVERSION != "5.5" && params.ROCMVERSION != "5.6"){ if (params.ROCMVERSION != "5.6"){
if (params.COMPILER_VERSION == "") { if (params.COMPILER_VERSION == "") {
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}" img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}"
} }
...@@ -597,7 +597,7 @@ def process_results(Map conf=[:]){ ...@@ -597,7 +597,7 @@ def process_results(Map conf=[:]){
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true
0 21 * * * % ROCMVERSION=5.4.3;COMPILER_VERSION=release;COMPILER_COMMIT= 0 21 * * * % ROCMVERSION=5.5;COMPILER_VERSION=release;COMPILER_COMMIT=
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : "" 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : ""
pipeline { pipeline {
......
# Composable Kernel # Composable Kernel
## Methodology ## Methodology
Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++. Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++.
CK utilizes two concepts to achieve performance portability and code maintainability: CK utilizes two concepts to achieve performance portability and code maintainability:
...@@ -10,6 +11,7 @@ CK utilizes two concepts to achieve performance portability and code maintainabi ...@@ -10,6 +11,7 @@ CK utilizes two concepts to achieve performance portability and code maintainabi
![ALT](/docs/data/ck_component.png "CK Components") ![ALT](/docs/data/ck_component.png "CK Components")
## Code Structure ## Code Structure
The current CK library is structured into 4 layers:
* "Templated Tile Operators" layer * "Templated Tile Operators" layer
* "Templated Kernel and Invoker" layer * "Templated Kernel and Invoker" layer
...@@ -24,30 +26,35 @@ Run the steps below to build documentation locally. ...@@ -24,30 +26,35 @@ Run the steps below to build documentation locally.
``` ```
cd docs cd docs
pip3 install -r .sphinx/requirements.txt pip3 install -r sphinx/requirements.txt
python3 -m sphinx -T -E -b html -d _build/doctrees -D language=en . _build/html python3 -m sphinx -T -E -b html -d _build/doctrees -D language=en . _build/html
``` ```
## Contributors ## Contributors
The list of developers and contributors is here: [Contributors](/CONTRIBUTORS.md) The list of developers and contributors is here: [Contributors](/CONTRIBUTORS.md)
## Citation ## Citation
If you use CK, please use following citations: If you use CK, please use following citations:
* CK paper will be freely available on arXiv soon: [Realizing Tensor Operators Using Coordinate Transformations and Tile Based Programming](???) * CK paper will be freely available on arXiv soon: [Realizing Tensor Operators Using Coordinate Transformations and Tile Based Programming](???)
* [CITATION.cff](/CITATION.cff) * [CITATION.cff](/CITATION.cff)
## License ## License
CK is released under the MIT license. [License File](/LICENSE) CK is released under the MIT license. [License File](/LICENSE)
# Build CK # Build CK
## Build docker image ## Build docker image
```bash ```bash
DOCKER_BUILDKIT=1 docker build -t ck:latest -f Dockerfile . DOCKER_BUILDKIT=1 docker build -t ck:latest -f Dockerfile .
``` ```
## Launch docker ## Launch docker
```bash ```bash
docker run \ docker run \
-it \ -it \
...@@ -60,10 +67,12 @@ ck:latest \ ...@@ -60,10 +67,12 @@ ck:latest \
``` ```
## Build CK ## Build CK
```bash ```bash
mkdir build && cd build mkdir build && cd build
# Need to specify target ID, example below is for gfx908 and gfx90a # Need to specify target ID, example below is for gfx908 and gfx90a
cmake \ cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
...@@ -74,6 +83,7 @@ cmake ...@@ -74,6 +83,7 @@ cmake
``` ```
### Build examples and tests ### Build examples and tests
```bash ```bash
make -j examples tests make -j examples tests
make test make test
...@@ -83,21 +93,25 @@ Instructions for running each individual examples are under [example](/example) ...@@ -83,21 +93,25 @@ Instructions for running each individual examples are under [example](/example)
## Build ckProfiler ## Build ckProfiler
```bash ```bash
make -j ckProfiler make -j ckProfiler
``` ```
Instructions for running ckProfiler are under [profiler](/profiler) Instructions for running ckProfiler are under [profiler](/profiler)
## Install CK ## Install CK
```bash ```bash
make install make install
``` ```
## Using CK as pre-built kernel library ## Using CK as pre-built kernel library
Instructions for using CK as a pre-built kernel library are under [client_example](/client_example) Instructions for using CK as a pre-built kernel library are under [client_example](/client_example)
## Caveat ## Caveat
### Kernel Timing and Verification ### Kernel Timing and Verification
CK's own kernel timer warms up the kernel once, and then runs it multiple times
to get the average kernel time. For kernels that use atomic add, this causes the
output buffer to be accumulated multiple times, leading to verification failures.
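The sketch below (plain standalone C++, not CK code; the array sizes and repeat count are made up for illustration) shows the failure mode described above: if a timed kernel accumulates into its output, as AtomicAdd-based kernels do, replaying it several times without clearing the output buffer leaves a result many times larger than a single-run reference.

```cpp
// Minimal standalone illustration of the timing/verification caveat (not CK code).
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    const int nrepeat = 10; // the timer replays the "kernel" 10 times to average its runtime
    std::vector<float> in(4, 1.0f);
    std::vector<float> out(4, 0.0f);

    for(int r = 0; r < nrepeat; ++r)             // timing loop
        for(std::size_t i = 0; i < in.size(); ++i)
            out[i] += in[i];                     // accumulate, like an AtomicAdd output

    // A single-run reference expects out[i] == 1.0, but the timed buffer holds 10.0,
    // so verification against the reference fails.
    std::printf("out[0] = %f (single-run reference expects 1.0)\n", out[0]);
    return 0;
}
```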
......
...@@ -4,10 +4,21 @@ ...@@ -4,10 +4,21 @@
# list see the documentation: # list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html # https://www.sphinx-doc.org/en/master/usage/configuration.html
import subprocess
from rocm_docs import ROCmDocs from rocm_docs import ROCmDocs
docs_core = ROCmDocs("Composable Kernel Documentation")
docs_core.run_doxygen() name = "Composable Kernel"
get_version = r'sed -n -e "s/^rocm_setup_version(.* \([0-9\.]\{1,\}\).*/\1/p" ../CMakeLists.txt'
version = subprocess.getoutput(get_version)
if len(version) > 0:
name = f"{name} {version}"
external_toc_path = "./sphinx/_toc.yml"
docs_core = ROCmDocs(f"{name} Documentation")
docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/docBin/xml")
docs_core.setup() docs_core.setup()
mathjax3_config = { mathjax3_config = {
......
=======
License
=======
.. include:: ../LICENSE
:literal:
# Anywhere {branch} is used, the branch name will be substituted.
# These comments will also be removed.
defaults:
numbered: False
maxdepth: 6
root: index
subtrees:
- caption: About
entries:
- file: license
rocm-docs-core==0.2.0 rocm-docs-core==0.10.3
sphinxcontrib-bibtex==2.5.0 sphinxcontrib-bibtex==2.5.0
# #
# This file is autogenerated by pip-compile with Python 3.10 # This file is autogenerated by pip-compile with Python 3.8
# by the following command: # by the following command:
# #
# pip-compile .sphinx/requirements.in # pip-compile requirements.in
# #
accessible-pygments==0.0.3 accessible-pygments==0.0.3
# via pydata-sphinx-theme # via pydata-sphinx-theme
alabaster==0.7.13 alabaster==0.7.13
# via sphinx # via sphinx
asttokens==2.2.1
# via stack-data
attrs==22.2.0
# via
# jsonschema
# jupyter-cache
babel==2.12.1 babel==2.12.1
# via # via
# pydata-sphinx-theme # pydata-sphinx-theme
# sphinx # sphinx
backcall==0.2.0
# via ipython
beautifulsoup4==4.11.2 beautifulsoup4==4.11.2
# via pydata-sphinx-theme # via pydata-sphinx-theme
breathe==4.34.0 breathe==4.34.0
...@@ -27,19 +19,15 @@ breathe==4.34.0 ...@@ -27,19 +19,15 @@ breathe==4.34.0
certifi==2022.12.7 certifi==2022.12.7
# via requests # via requests
cffi==1.15.1 cffi==1.15.1
# via pynacl # via
# cryptography
# pynacl
charset-normalizer==3.1.0 charset-normalizer==3.1.0
# via requests # via requests
click==8.1.3 click==8.1.3
# via # via sphinx-external-toc
# jupyter-cache cryptography==40.0.2
# sphinx-external-toc # via pyjwt
comm==0.1.2
# via ipykernel
debugpy==1.6.6
# via ipykernel
decorator==5.1.1
# via ipython
deprecated==1.2.13 deprecated==1.2.13
# via pygithub # via pygithub
docutils==0.16 docutils==0.16
...@@ -48,52 +36,26 @@ docutils==0.16 ...@@ -48,52 +36,26 @@ docutils==0.16
# myst-parser # myst-parser
# pybtex-docutils # pybtex-docutils
# pydata-sphinx-theme # pydata-sphinx-theme
# rocm-docs-core
# sphinx # sphinx
# sphinxcontrib-bibtex # sphinxcontrib-bibtex
executing==1.2.0
# via stack-data
fastjsonschema==2.16.3
# via nbformat
gitdb==4.0.10 gitdb==4.0.10
# via gitpython # via gitpython
gitpython==3.1.31 gitpython==3.1.31
# via rocm-docs-core # via rocm-docs-core
greenlet==2.0.2
# via sqlalchemy
idna==3.4 idna==3.4
# via requests # via requests
imagesize==1.4.1 imagesize==1.4.1
# via sphinx # via sphinx
importlib-metadata==6.0.0 importlib-metadata==6.0.0
# via # via
# jupyter-cache # sphinx
# myst-nb # sphinxcontrib-bibtex
ipykernel==6.21.3 importlib-resources==5.12.0
# via myst-nb # via rocm-docs-core
ipython==8.11.0
# via
# ipykernel
# myst-nb
jedi==0.18.2
# via ipython
jinja2==3.1.2 jinja2==3.1.2
# via # via
# myst-parser # myst-parser
# sphinx # sphinx
jsonschema==4.17.3
# via nbformat
jupyter-cache==0.5.0
# via myst-nb
jupyter-client==8.0.3
# via
# ipykernel
# nbclient
jupyter-core==5.3.0
# via
# ipykernel
# jupyter-client
# nbformat
latexcodec==2.0.1 latexcodec==2.0.1
# via pybtex # via pybtex
linkify-it-py==1.0.3 linkify-it-py==1.0.3
...@@ -104,54 +66,16 @@ markdown-it-py==2.2.0 ...@@ -104,54 +66,16 @@ markdown-it-py==2.2.0
# myst-parser # myst-parser
markupsafe==2.1.2 markupsafe==2.1.2
# via jinja2 # via jinja2
matplotlib-inline==0.1.6
# via
# ipykernel
# ipython
mdit-py-plugins==0.3.5 mdit-py-plugins==0.3.5
# via myst-parser # via myst-parser
mdurl==0.1.2 mdurl==0.1.2
# via markdown-it-py # via markdown-it-py
myst-nb==0.17.1 myst-parser[linkify]==1.0.0
# via rocm-docs-core # via rocm-docs-core
myst-parser[linkify]==0.18.1
# via
# myst-nb
# rocm-docs-core
nbclient==0.5.13
# via
# jupyter-cache
# myst-nb
nbformat==5.7.3
# via
# jupyter-cache
# myst-nb
# nbclient
nest-asyncio==1.5.6
# via
# ipykernel
# nbclient
packaging==23.0 packaging==23.0
# via # via
# ipykernel
# pydata-sphinx-theme # pydata-sphinx-theme
# sphinx # sphinx
parso==0.8.3
# via jedi
pexpect==4.8.0
# via ipython
pickleshare==0.7.5
# via ipython
platformdirs==3.1.1
# via jupyter-core
prompt-toolkit==3.0.38
# via ipython
psutil==5.9.4
# via ipykernel
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.2
# via stack-data
pybtex==0.24.0 pybtex==0.24.0
# via # via
# pybtex-docutils # pybtex-docutils
...@@ -160,57 +84,47 @@ pybtex-docutils==1.0.2 ...@@ -160,57 +84,47 @@ pybtex-docutils==1.0.2
# via sphinxcontrib-bibtex # via sphinxcontrib-bibtex
pycparser==2.21 pycparser==2.21
# via cffi # via cffi
pydata-sphinx-theme==0.13.1 pydata-sphinx-theme==0.13.3
# via sphinx-book-theme # via
pygithub==1.57 # rocm-docs-core
# sphinx-book-theme
pygithub==1.58.2
# via rocm-docs-core # via rocm-docs-core
pygments==2.14.0 pygments==2.14.0
# via # via
# accessible-pygments # accessible-pygments
# ipython
# pydata-sphinx-theme # pydata-sphinx-theme
# sphinx # sphinx
pyjwt==2.6.0 pyjwt[crypto]==2.6.0
# via pygithub # via pygithub
pynacl==1.5.0 pynacl==1.5.0
# via pygithub # via pygithub
pyrsistent==0.19.3 pytz==2023.3
# via jsonschema # via babel
python-dateutil==2.8.2
# via jupyter-client
pyyaml==6.0 pyyaml==6.0
# via # via
# jupyter-cache
# myst-nb
# myst-parser # myst-parser
# pybtex # pybtex
# sphinx-external-toc # sphinx-external-toc
pyzmq==25.0.1
# via
# ipykernel
# jupyter-client
requests==2.28.2 requests==2.28.2
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core==0.2.0 rocm-docs-core==0.10.3
# via -r .sphinx/requirements.in # via -r requirements.in
six==1.16.0 six==1.16.0
# via # via
# asttokens
# latexcodec # latexcodec
# pybtex # pybtex
# python-dateutil
smmap==5.0.0 smmap==5.0.0
# via gitdb # via gitdb
snowballstemmer==2.2.0 snowballstemmer==2.2.0
# via sphinx # via sphinx
soupsieve==2.4 soupsieve==2.4
# via beautifulsoup4 # via beautifulsoup4
sphinx==4.3.1 sphinx==5.3.0
# via # via
# breathe # breathe
# myst-nb
# myst-parser # myst-parser
# pydata-sphinx-theme # pydata-sphinx-theme
# rocm-docs-core # rocm-docs-core
...@@ -220,7 +134,7 @@ sphinx==4.3.1 ...@@ -220,7 +134,7 @@ sphinx==4.3.1
# sphinx-external-toc # sphinx-external-toc
# sphinx-notfound-page # sphinx-notfound-page
# sphinxcontrib-bibtex # sphinxcontrib-bibtex
sphinx-book-theme==1.0.0rc2 sphinx-book-theme==1.0.1
# via rocm-docs-core # via rocm-docs-core
sphinx-copybutton==0.5.1 sphinx-copybutton==0.5.1
# via rocm-docs-core # via rocm-docs-core
...@@ -233,7 +147,7 @@ sphinx-notfound-page==0.8.3 ...@@ -233,7 +147,7 @@ sphinx-notfound-page==0.8.3
sphinxcontrib-applehelp==1.0.4 sphinxcontrib-applehelp==1.0.4
# via sphinx # via sphinx
sphinxcontrib-bibtex==2.5.0 sphinxcontrib-bibtex==2.5.0
# via -r .sphinx/requirements.in # via -r requirements.in
sphinxcontrib-devhelp==1.0.2 sphinxcontrib-devhelp==1.0.2
# via sphinx # via sphinx
sphinxcontrib-htmlhelp==2.0.1 sphinxcontrib-htmlhelp==2.0.1
...@@ -244,40 +158,15 @@ sphinxcontrib-qthelp==1.0.3 ...@@ -244,40 +158,15 @@ sphinxcontrib-qthelp==1.0.3
# via sphinx # via sphinx
sphinxcontrib-serializinghtml==1.1.5 sphinxcontrib-serializinghtml==1.1.5
# via sphinx # via sphinx
sqlalchemy==1.4.46
# via jupyter-cache
stack-data==0.6.2
# via ipython
tabulate==0.9.0
# via jupyter-cache
tornado==6.2
# via
# ipykernel
# jupyter-client
traitlets==5.9.0
# via
# comm
# ipykernel
# ipython
# jupyter-client
# jupyter-core
# matplotlib-inline
# nbclient
# nbformat
typing-extensions==4.5.0 typing-extensions==4.5.0
# via # via pydata-sphinx-theme
# myst-nb
# myst-parser
uc-micro-py==1.0.1 uc-micro-py==1.0.1
# via linkify-it-py # via linkify-it-py
urllib3==1.26.15 urllib3==1.26.15
# via requests # via requests
wcwidth==0.2.6
# via prompt-toolkit
wrapt==1.15.0 wrapt==1.15.0
# via deprecated # via deprecated
zipp==3.15.0 zipp==3.15.0
# via importlib-metadata # via
# importlib-metadata
# The following packages are considered to be unsafe in a requirements file: # importlib-resources
# setuptools
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp" #include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device:: ...@@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device::
using DeviceOpInstance = DeviceOpInstanceKKNN; using DeviceOpInstance = DeviceOpInstanceKKNN;
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks},
e_ms_ns_{e_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_;
Tensor<EDataType>& e_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_M2_N2_K2::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
for(int k1 = 0; k1 < K1; ++k1)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b;
}
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_M2_N2_K2"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
...@@ -385,7 +251,8 @@ int main(int argc, char* argv[]) ...@@ -385,7 +251,8 @@ int main(int argc, char* argv[])
{ {
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM, using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN, NumDimN,
NumDimK, NumDimK,
ADataType, ADataType,
...@@ -393,14 +260,13 @@ int main(int argc, char* argv[]) ...@@ -393,14 +260,13 @@ int main(int argc, char* argv[])
CShuffleDataType, CShuffleDataType,
AccDataType, AccDataType,
AElementOp, AElementOp,
BElementOp, BElementOp>;
PassThrough>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument( auto ref_argument =
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
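For reference, both the local ReferenceContraction_M2_N2_K2 removed above and the shared ck::tensor_operation::host::ReferenceContraction_M2_N2_K2 that replaces it compute, for the 2+2+2 dimensional case, E[m0][m1][n0][n1] = cde( sum over k0,k1 of a(A[m0][m1][k0][k1]) * b(B[n0][n1][k0][k1]) ). The standalone sketch below (plain C++ with made-up toy sizes, pass-through elementwise operations, and no CK dependencies) mirrors that loop nest for illustration only.

```cpp
// Minimal host-side sketch of the M2_N2_K2 tensor contraction (not the library header).
#include <cstdio>
#include <vector>

int main()
{
    const int M0 = 2, M1 = 2, N0 = 2, N1 = 2, K0 = 3, K1 = 4; // toy sizes

    // Flat row-major storage: A is [M0][M1][K0][K1], B is [N0][N1][K0][K1], E is [M0][M1][N0][N1].
    std::vector<float> A(M0 * M1 * K0 * K1, 1.0f);
    std::vector<float> B(N0 * N1 * K0 * K1, 2.0f);
    std::vector<float> E(M0 * M1 * N0 * N1, 0.0f);

    auto a_idx = [&](int m0, int m1, int k0, int k1) { return ((m0 * M1 + m1) * K0 + k0) * K1 + k1; };
    auto b_idx = [&](int n0, int n1, int k0, int k1) { return ((n0 * N1 + n1) * K0 + k0) * K1 + k1; };
    auto e_idx = [&](int m0, int m1, int n0, int n1) { return ((m0 * M1 + m1) * N0 + n0) * N1 + n1; };

    for(int m0 = 0; m0 < M0; ++m0)
        for(int m1 = 0; m1 < M1; ++m1)
            for(int n0 = 0; n0 < N0; ++n0)
                for(int n1 = 0; n1 < N1; ++n1)
                {
                    float acc = 0.f; // accumulator (AccDataType in the CK reference)
                    for(int k0 = 0; k0 < K0; ++k0)
                        for(int k1 = 0; k1 < K1; ++k1)
                            acc += A[a_idx(m0, m1, k0, k1)] * B[b_idx(n0, n1, k0, k1)];
                    E[e_idx(m0, m1, n0, n1)] = acc; // pass-through CDE elementwise op
                }

    // With all-ones A and all-twos B, every E entry equals K0 * K1 * 2 = 24.
    std::printf("E[0] = %f\n", E[0]);
    return 0;
}
```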
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp" #include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device:: ...@@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device::
using DeviceOpInstance = DeviceOpInstanceKKNN; using DeviceOpInstance = DeviceOpInstanceKKNN;
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks},
e_ms_ns_{e_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_;
Tensor<EDataType>& e_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_M2_N2_K2::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
for(int k1 = 0; k1 < K1; ++k1)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b;
}
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_M2_N2_K2"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
...@@ -385,7 +251,8 @@ int main(int argc, char* argv[]) ...@@ -385,7 +251,8 @@ int main(int argc, char* argv[])
{ {
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM, using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN, NumDimN,
NumDimK, NumDimK,
ADataType, ADataType,
...@@ -393,14 +260,13 @@ int main(int argc, char* argv[]) ...@@ -393,14 +260,13 @@ int main(int argc, char* argv[])
CShuffleDataType, CShuffleDataType,
AccDataType, AccDataType,
AElementOp, AElementOp,
BElementOp, BElementOp>;
PassThrough>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument( auto ref_argument =
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp" #include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device:: ...@@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device::
using DeviceOpInstance = DeviceOpInstanceKKN; using DeviceOpInstance = DeviceOpInstanceKKN;
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks},
e_ms_ns_{e_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_;
Tensor<EDataType>& e_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_M2_N2_K2::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
for(int k1 = 0; k1 < K1; ++k1)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b;
}
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_M2_N2_K2"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
...@@ -368,7 +234,8 @@ int main(int argc, char* argv[]) ...@@ -368,7 +234,8 @@ int main(int argc, char* argv[])
{ {
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM, using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN, NumDimN,
NumDimK, NumDimK,
ADataType, ADataType,
...@@ -376,14 +243,14 @@ int main(int argc, char* argv[]) ...@@ -376,14 +243,14 @@ int main(int argc, char* argv[])
CShuffleDataType, CShuffleDataType,
AccDataType, AccDataType,
AElementOp, AElementOp,
BElementOp, BElementOp>;
PassThrough>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument( Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); auto ref_argument =
ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp" #include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device:: ...@@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device::
using DeviceOpInstance = DeviceOpInstanceKKN; using DeviceOpInstance = DeviceOpInstanceKKN;
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_ms_ks_{a_ms_ks},
b_ns_ks_{b_ns_ks},
e_ms_ns_{e_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_ms_ks_;
const Tensor<BDataType>& b_ns_ks_;
Tensor<EDataType>& e_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_M2_N2_K2::Argument;
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
for(int k1 = 0; k1 < K1; ++k1)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b;
}
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
const Tensor<BDataType>& b_ns_ks,
Tensor<EDataType>& e_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_M2_N2_K2"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
...@@ -368,7 +234,8 @@ int main(int argc, char* argv[]) ...@@ -368,7 +234,8 @@ int main(int argc, char* argv[])
{ {
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM, using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN, NumDimN,
NumDimK, NumDimK,
ADataType, ADataType,
...@@ -376,14 +243,14 @@ int main(int argc, char* argv[]) ...@@ -376,14 +243,14 @@ int main(int argc, char* argv[])
CShuffleDataType, CShuffleDataType,
AccDataType, AccDataType,
AElementOp, AElementOp,
BElementOp, BElementOp>;
PassThrough>;
auto ref_gemm = ReferenceOpInstance{}; auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument( Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); auto ref_argument =
ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
add_example_executable(example_layernorm_blockwise layernorm_blockwise.cpp) add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp)
add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 2;
constexpr int NumReduceDim = 1;
using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
PassThrough,
Rank,
NumReduceDim,
256, // BlockSize
8, // ClusterM
32, // ClusterK
1, // SliceM
8, // SliceK
1, // XYVectorDim (0=M, 1=K)
8, // SrcScalarPerVector
1, // GammaVecDim (0=M, 1=K)
8, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K)
8, // BetaScalarPerVector
8>; // OutScalarPerVector
#include "run_layernorm_example.inc"
int main() { return run_layernorm_example<DeviceInstance>(); }
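For orientation, the template arguments above describe a 256-thread block arranged as an 8 x 32 (ClusterM x ClusterK) thread cluster, with each thread handling a 1 x 8 (SliceM x SliceK) slice and 8-wide vector accesses along K. The host-side sketch below (plain C++, not the library's reference_layernorm.hpp; the sizes and epsilon are made up) shows the per-row computation such a layernorm instance performs on an [M, K] input: y[m][k] = gamma[k] * (x[m][k] - mean_m) / sqrt(var_m + eps) + beta[k].

```cpp
// Minimal host-side layernorm sketch over the last dimension of an [M, K] tensor.
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int M = 2, K = 8;      // toy sizes; Rank = 2, reduction over the last dim
    const float eps = 1e-4f;     // made-up epsilon for illustration
    std::vector<float> x(M * K), gamma(K, 1.0f), beta(K, 0.0f), y(M * K);
    for(int i = 0; i < M * K; ++i)
        x[i] = static_cast<float>(i % K); // simple deterministic input

    for(int m = 0; m < M; ++m)
    {
        float mean = 0.f, var = 0.f;
        for(int k = 0; k < K; ++k)
            mean += x[m * K + k];
        mean /= K;
        for(int k = 0; k < K; ++k)
        {
            const float d = x[m * K + k] - mean;
            var += d * d;
        }
        var /= K;
        const float inv_std = 1.0f / std::sqrt(var + eps);
        for(int k = 0; k < K; ++k)
            y[m * K + k] = gamma[k] * (x[m * K + k] - mean) * inv_std + beta[k];
    }

    std::printf("y[0] = %f\n", y[0]);
    return 0;
}
```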