"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "5b82960df840a8bd545b9a60a1b69c089e0e24f1"
Commit e330961d authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'develop' of...

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into e2e_kernellib
parents 8e862b7b 1e59eb3b
...@@ -7,6 +7,8 @@ ARG compiler_commit="" ...@@ -7,6 +7,8 @@ ARG compiler_commit=""
RUN set -xe RUN set -xe
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
RUN useradd -rm -d /home/jenkins -s /bin/bash -u 1004 jenkins
RUN useradd -rm -d /home/manitera -s /bin/bash -u 1002 manitera
# Add rocm repository # Add rocm repository
RUN apt-get update RUN apt-get update
RUN apt-get install -y wget gnupg RUN apt-get install -y wget gnupg
...@@ -37,6 +39,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -37,6 +39,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
python-dev \ python-dev \
python3-dev \ python3-dev \
python3-pip \ python3-pip \
sshpass \
software-properties-common \ software-properties-common \
rocm-dev \ rocm-dev \
rocm-device-libs \ rocm-device-libs \
......
...@@ -14,7 +14,6 @@ def show_node_info() { ...@@ -14,7 +14,6 @@ def show_node_info() {
def runShell(String command){ def runShell(String command){
def responseCode = sh returnStatus: true, script: "${command} > tmp.txt" def responseCode = sh returnStatus: true, script: "${command} > tmp.txt"
def output = readFile(file: "tmp.txt") def output = readFile(file: "tmp.txt")
echo "tmp.txt contents: $output"
return (output != "") return (output != "")
} }
...@@ -172,7 +171,7 @@ def cmake_build(Map conf=[:]){ ...@@ -172,7 +171,7 @@ def cmake_build(Map conf=[:]){
if(conf.get("build_install","") == "true") if(conf.get("build_install","") == "true")
{ {
config_targets = 'install ' + config_targets config_targets = 'install ' + config_targets
setup_args = ' -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install' + setup_args setup_args = ' -DBUILD_DEV=On -DCMAKE_INSTALL_PREFIX=../install' + setup_args
} else{ } else{
setup_args = ' -DBUILD_DEV=On' + setup_args setup_args = ' -DBUILD_DEV=On' + setup_args
} }
...@@ -427,6 +426,7 @@ def Build_CK(Map conf=[:]){ ...@@ -427,6 +426,7 @@ def Build_CK(Map conf=[:]){
def variant = env.STAGE_NAME def variant = env.STAGE_NAME
def retimage def retimage
def navi_node = 0
gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
try { try {
...@@ -440,6 +440,9 @@ def Build_CK(Map conf=[:]){ ...@@ -440,6 +440,9 @@ def Build_CK(Map conf=[:]){
else{ else{
echo "GPU is OK" echo "GPU is OK"
} }
if ( runShell('grep -n "gfx1030" clinfo.log') ){
navi_node = 1
}
} }
} }
} }
...@@ -458,6 +461,9 @@ def Build_CK(Map conf=[:]){ ...@@ -458,6 +461,9 @@ def Build_CK(Map conf=[:]){
else{ else{
echo "GPU is OK" echo "GPU is OK"
} }
if ( runShell('grep -n "gfx1030" clinfo.log') ){
navi_node = 1
}
} }
} }
} }
...@@ -466,16 +472,20 @@ def Build_CK(Map conf=[:]){ ...@@ -466,16 +472,20 @@ def Build_CK(Map conf=[:]){
{ {
cmake_build(conf) cmake_build(conf)
dir("build"){ dir("build"){
//run tests and examples if (navi_node == 0 ){
sh 'make -j check' //run tests and examples on all nodes except Navi
//we only need the ckProfiler to run the performance tests, so we pack and stash it sh 'make -j check'
sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' //we only need the ckProfiler to run the performance tests, so we pack and stash it
stash "ckProfiler.tar.gz" sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler'
stash "ckProfiler.tar.gz"
}
if (params.RUN_FULL_QA){ if (params.RUN_FULL_QA){
// build deb packages // build deb packages
sh 'make -j package' sh 'make -j package'
archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb'
archiveArtifacts artifacts: 'composablekernel-tests_*.deb' archiveArtifacts artifacts: 'composablekernel-tests_*.deb'
sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb'
stash "ckprofiler_0.2.0_amd64.deb"
} }
} }
} }
...@@ -543,6 +553,8 @@ def process_results(Map conf=[:]){ ...@@ -543,6 +553,8 @@ def process_results(Map conf=[:]){
unstash "perf_splitK_gemm.log" unstash "perf_splitK_gemm.log"
unstash "perf_onnx_gemm.log" unstash "perf_onnx_gemm.log"
sh "./process_qa_data.sh" sh "./process_qa_data.sh"
unstash "ckprofiler_0.2.0_amd64.deb"
sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/"
} }
else{ else{
// unstash perf files to master // unstash perf files to master
...@@ -564,7 +576,7 @@ def process_results(Map conf=[:]){ ...@@ -564,7 +576,7 @@ def process_results(Map conf=[:]){
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true
0 21 * * * % RUN_FULL_QA=false;COMPILER_VERSION=release;COMPILER_COMMIT= 0 21 * * * % COMPILER_VERSION=release;COMPILER_COMMIT=
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : "" 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : ""
pipeline { pipeline {
...@@ -653,12 +665,28 @@ pipeline { ...@@ -653,12 +665,28 @@ pipeline {
{ {
parallel parallel
{ {
stage("Build CK and run Tests") stage("Build CK and run Tests on MI100/MI200")
{ {
agent{ label rocmnode("gfx908 || gfx90a") } agent{ label rocmnode("gfx908 || gfx90a") }
environment{ environment{
setup_args = "${params.COMPILER_VERSION == "ck-9110" ? """ -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx1030" -DCMAKE_CXX_FLAGS="-O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" """ : """ -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx1030" -DCMAKE_CXX_FLAGS="-O3 " """ }" setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" """
execute_args = "${params.COMPILER_VERSION == "ck-9110" ? """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx1030" -DCMAKE_CXX_FLAGS="-O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ : """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908,gfx90a;gfx1030" -DCMAKE_CXX_FLAGS="-O3" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ }" execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908,gfx90a" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
}
}
stage("Build CK and run Tests on Navi")
{
when {
beforeAgent true
expression { !params.RUN_FULL_QA.toBoolean() }
}
agent{ label rocmnode("navi21") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
} }
steps{ steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
...@@ -671,7 +699,7 @@ pipeline { ...@@ -671,7 +699,7 @@ pipeline {
{ {
parallel parallel
{ {
stage("Run ckProfiler: gfx908 or gfx90a") stage("Run ckProfiler: gfx90*")
{ {
when { when {
beforeAgent true beforeAgent true
...@@ -680,7 +708,7 @@ pipeline { ...@@ -680,7 +708,7 @@ pipeline {
options { retry(2) } options { retry(2) }
agent{ label rocmnode("gfx908 || gfx90a")} agent{ label rocmnode("gfx908 || gfx90a")}
environment{ environment{
setup_args = "${params.COMPILER_VERSION == "ck-9110" ? """ -DGPU_TARGETS="gfx908;gfx90a;gfx1030" -DCMAKE_CXX_FLAGS=" -O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" -DBUILD_DEV=On """ : """ -DGPU_TARGETS="gfx908;gfx90a;gfx1030" -DCMAKE_CXX_FLAGS=" -O3 " -DBUILD_DEV=On """}" setup_args = """ -DGPU_TARGETS="gfx908;gfx90a" -DBUILD_DEV=On """
} }
steps{ steps{
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
...@@ -695,7 +723,7 @@ pipeline { ...@@ -695,7 +723,7 @@ pipeline {
options { retry(2) } options { retry(2) }
agent{ label rocmnode("gfx90a")} agent{ label rocmnode("gfx90a")}
environment{ environment{
setup_args = "${params.COMPILER_VERSION == "ck-9110" ? """ -DGPU_TARGETS="gfx90a" -DCMAKE_CXX_FLAGS=" -O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" -DBUILD_DEV=On """ : """ -DGPU_TARGETS="gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " -DBUILD_DEV=On """}" setup_args = """ -DGPU_TARGETS="gfx90a" -DBUILD_DEV=On """
} }
steps{ steps{
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
......
...@@ -65,7 +65,8 @@ else() ...@@ -65,7 +65,8 @@ else()
-Wuninitialized -Wuninitialized
-Wunreachable-code -Wunreachable-code
-Wunused -Wunused
-Wno-reserved-identifier
-Werror
-Wsign-compare -Wsign-compare
-Wno-extra-semi-stmt -Wno-extra-semi-stmt
) )
......
...@@ -775,8 +775,10 @@ WARN_LOGFILE = ...@@ -775,8 +775,10 @@ WARN_LOGFILE =
# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING
# Note: If this tag is empty the current directory is searched. # Note: If this tag is empty the current directory is searched.
INPUT = ../library/include \ INPUT = ../include/ck/tensor_operation/gpu/grid \
../library/include/internal ../include/ck/tensor_operation/gpu/block \
../include/ck/tensor_operation/gpu/thread \
../library/include/ck/library/utility
# This tag can be used to specify the character encoding of the source files # This tag can be used to specify the character encoding of the source files
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
...@@ -845,7 +847,7 @@ FILE_PATTERNS = *.c \ ...@@ -845,7 +847,7 @@ FILE_PATTERNS = *.c \
# be searched for input files as well. # be searched for input files as well.
# The default value is: NO. # The default value is: NO.
RECURSIVE = NO RECURSIVE = YES
# The EXCLUDE tag can be used to specify files and/or directories that should be # The EXCLUDE tag can be used to specify files and/or directories that should be
# excluded from the INPUT source files. This way you can easily exclude a # excluded from the INPUT source files. This way you can easily exclude a
......
=================== *******************
API Reference Guide API Reference Guide
=================== *******************
------------ =================
Introduction Introduction
------------ =================
This document contains details of the APIs for the Composable Kernel (CK) library and introduces some of the key design This document contains details of the APIs for the Composable Kernel (CK) library and introduces some of the key design
principles that are used to write new classes that extend CK functionality. principles that are used to write new classes that extend CK functionality.
...@@ -16,8 +16,37 @@ Using CK API ...@@ -16,8 +16,37 @@ Using CK API
This section describes how to use the CK library API. This section describes how to use the CK library API.
----------------- =================
CK Datatypes CK Datatypes
=================
-----------------
DeviceMem
----------------- -----------------
[TODO] .. doxygenstruct:: DeviceMem
\ No newline at end of file
---------------------------
Kernels For Flashattention
---------------------------
The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This section lists the classes that are
used in the CK GPU implementation of Flashattention.
**Gridwise classes**
.. doxygenstruct:: ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
**Blockwise classes**
.. doxygenstruct:: ck::ThreadGroupTensorSliceTransfer_v4r1
.. doxygenstruct:: ck::BlockwiseGemmXdlops_v2
.. doxygenstruct:: ck::BlockwiseSoftmax
**Threadwise classes**
.. doxygenstruct:: ck::ThreadwiseTensorSliceTransfer_StaticToStatic
.. bibliography::
\ No newline at end of file
...@@ -59,10 +59,13 @@ if read_the_docs_build: ...@@ -59,10 +59,13 @@ if read_the_docs_build:
# Add any Sphinx extension module names here, as strings. They can be # Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones. # ones.
extensions = ['sphinx.ext.mathjax', 'breathe'] extensions = ['sphinx.ext.mathjax', 'breathe', 'sphinxcontrib.bibtex']
breathe_projects = { "CK": "../docBin/xml" } breathe_projects = { "CK": "../docBin/xml" }
breathe_default_project = "CK" breathe_default_project = "CK"
bibtex_bibfiles = ['refs.bib']
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates'] templates_path = ['_templates']
......
@article{dao2022flashattention,
title={Flashattention: Fast and memory-efficient exact attention with io-awareness},
author={Dao, Tri and Fu, Daniel Y and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
journal={arXiv preprint arXiv:2205.14135},
year={2022}
}
...@@ -622,11 +622,16 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() ...@@ -622,11 +622,16 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
} }
}; };
// Blockwise gemm supporting /**
// 1. regular XDL output M2_M3_M4_M2 and transposed XDL output M2_N2_N3_N4 * @brief Blockwise gemm
// 2. decoupled input tile descriptor and mma tile descriptor in order to support both vgpr and LDS *
// source buffer * Supports
// 3. configurable k index starting position and step size after each FMA/XDL instruction * 1. regular XDL output M2_M3_M4_M2 and transposed XDL output M2_N2_N3_N4
* 2. decoupled input tile descriptor and mma tile descriptor in order to support both vgpr and LDS
* source buffer
* 3. configurable k index starting position and step size after each FMA/XDL instruction
*/
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
......
...@@ -12,6 +12,16 @@ ...@@ -12,6 +12,16 @@
namespace ck { namespace ck {
/**
* @brief Blockwise softmax
*
* @tparam BlockSize Block size
* @tparam AccDataType Accumulator data type
* @tparam ThreadMap_M_K Thread id to m_k
* @tparam ThreadClusterDesc_M_K Threadwise cluster descriptor
* @tparam ThreadSliceDesc_M_K Threadwise slices descriptor
* @tparam IgnoreNaN Flag to ignore NaN, false by default
*/
template <index_t BlockSize, template <index_t BlockSize,
typename AccDataType, typename AccDataType,
typename ThreadMap_M_K, // thread_id to m_k typename ThreadMap_M_K, // thread_id to m_k
......
...@@ -11,10 +11,15 @@ ...@@ -11,10 +11,15 @@
namespace ck { namespace ck {
// this version does following things to avoid scratch memory issue /**
// 1. Use StaticallyIndexedArray instead of C array for thread buffer * @brief Blockwise data transfer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor *
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate * This version does following things to avoid scratch memory issue
* 1. Use StaticallyIndexedArray instead of C array for thread buffer
* 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
* 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
*
*/
template <typename ThreadGroup, template <typename ThreadGroup,
typename SrcElementwiseOperation, typename SrcElementwiseOperation,
typename DstElementwiseOperation, typename DstElementwiseOperation,
......
...@@ -166,7 +166,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -166,7 +166,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{ {
const auto b_grid_desc_nraw_kraw = const auto b_grid_desc_nraw_kraw =
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1)); make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));
...@@ -208,39 +208,22 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -208,39 +208,22 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
} }
template <typename ELayout_> template <typename ELayout_>
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{ {
const auto e_grid_desc_m_n = [&]() { const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout_>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout_>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1)); return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout_>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout_>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE)); return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
} }
}(); }();
if constexpr(GemmSpec == GemmSpecialization::MNPadding) return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
} }
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& Ms, static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& Ms,
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -157,7 +158,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -157,7 +158,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{ {
const auto b_grid_desc_nraw_kraw = const auto b_grid_desc_nraw_kraw =
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1)); make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));
......
...@@ -18,6 +18,10 @@ ...@@ -18,6 +18,10 @@
namespace ck { namespace ck {
/**
* @brief Gridwise gemm + softmax + gemm fusion
*
*/
template <typename FloatAB, template <typename FloatAB,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
......
...@@ -1201,7 +1201,12 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1201,7 +1201,12 @@ struct ThreadwiseTensorSliceTransfer_v4
SrcCoord src_ref_coord_; SrcCoord src_ref_coord_;
}; };
// Do NOT involve any tensor coordinates with StaticBuffer /**
* @brief Threadwise data transfer
*
* Do NOT involve any tensor coordinates with StaticBuffer
*
*/
template <typename SrcData, template <typename SrcData,
typename DstData, typename DstData,
typename SrcDesc, typename SrcDesc,
......
...@@ -14,6 +14,10 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) ...@@ -14,6 +14,10 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
} }
} }
/**
* @brief Container for storing data in GPU device memory
*
*/
struct DeviceMem struct DeviceMem
{ {
DeviceMem() = delete; DeviceMem() = delete;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment