Unverified commit a5abe1ad, authored by zjing14, committed by GitHub

Merge branch 'develop' into aosewski/ggemm_splitk

parents 0b7a77c2 fd11a4a1
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
  - package-ecosystem: "pip" # See documentation for possible values
    directory: "/" # Location of package manifests
    open-pull-requests-limit: 10
    schedule:
      interval: "daily"
@@ -73,7 +73,7 @@ int main(int argc, char* argv[])
     SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C);
     SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C);
     SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C);
-    SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K * Y * X * C);
+    SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K);
     SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K);

     using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
...
@@ -76,7 +76,7 @@ int main(int argc, char* argv[])
     SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C);
     SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C);
     SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C);
-    SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K * Y * X * C);
+    SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K);
     SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K);

     using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
...
@@ -69,7 +69,7 @@ int main(int argc, char* argv[])
     SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C);
     SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C);
-    SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K * Y * X * C);
+    SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K);
     SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K);

     using DeviceOp =
...
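The fix in all three example files above reflects that the requantization scale is per output channel: one value per K, not a weight-shaped K * Y * X * C tensor. A hedged scalar sketch of per-channel requantization (illustrative names, not the CK API):

#include <cstdint>
#include <vector>

// Each output channel k applies one scale to its int32 accumulator,
// which is why the scale buffer needs only K elements.
std::vector<int8_t> requantize_per_channel(const std::vector<int32_t>& acc,   // N*Ho*Wo*K accumulators
                                           const std::vector<float>& scale,   // K scales, one per channel
                                           int K)
{
    std::vector<int8_t> out(acc.size());
    for(std::size_t i = 0; i < acc.size(); ++i)
    {
        // NHWK layout: the channel index is the fastest-varying dimension.
        const float v = static_cast<float>(acc[i]) * scale[i % K];
        // Saturate to the int8 range before narrowing.
        out[i] = static_cast<int8_t>(v < -128.f ? -128.f : (v > 127.f ? 127.f : v));
    }
    return out;
}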
@@ -13,8 +13,8 @@
 #include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp"

 using XDataType       = ck::half_t;
-using GammaDataType   = ck::half_t;
-using BetaDataType    = ck::half_t;
+using GammaDataType   = float;
+using BetaDataType    = float;
 using YDataType       = ck::half_t;
 using ComputeDataType = float;
 using Swish           = ck::tensor_operation::element_wise::Swish;
...
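The example now feeds f16 activations with f32 gamma/beta, matching the mixed-precision instances added later in this commit. For reference, a scalar sketch of Swish, assuming the CK functor follows the common unit-beta definition (the functor's exact signature is not shown in this diff):

#include <cmath>

// Reference Swish: x * sigmoid(x) = x / (1 + exp(-x)), evaluated in fp32
// as ComputeDataType = float suggests.
float swish_ref(float x) { return x / (1.0f + std::exp(-x)); }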
@@ -21,6 +21,7 @@ list(APPEND GTEST_CMAKE_CXX_FLAGS
     -Wno-comma
     -Wno-old-style-cast
     -Wno-deprecated
+    -Wno-unsafe-buffer-usage
 )
 message(STATUS "Suppressing googletest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}")
...
-git+https://github.com/RadeonOpenCompute/rocm-docs-core.git
+rocm-docs-core==0.2.0
 sphinxcontrib-bibtex==2.5.0
@@ -2,9 +2,9 @@
 # This file is autogenerated by pip-compile with Python 3.10
 # by the following command:
 #
-#    pip-compile requirements.in
+#    pip-compile .sphinx/requirements.in
 #
-accessible-pygments==0.0.4
+accessible-pygments==0.0.3
     # via pydata-sphinx-theme
 alabaster==0.7.13
     # via sphinx
@@ -20,7 +20,7 @@ babel==2.12.1
     #   sphinx
 backcall==0.2.0
     # via ipython
-beautifulsoup4==4.12.0
+beautifulsoup4==4.11.2
     # via pydata-sphinx-theme
 breathe==4.34.0
     # via rocm-docs-core
@@ -34,7 +34,7 @@ click==8.1.3
     # via
     #   jupyter-cache
     #   sphinx-external-toc
-comm==0.1.3
+comm==0.1.2
     # via ipykernel
 debugpy==1.6.6
     # via ipykernel
@@ -65,13 +65,11 @@ idna==3.4
     # via requests
 imagesize==1.4.1
     # via sphinx
-importlib-metadata==6.1.0
+importlib-metadata==6.0.0
     # via
     #   jupyter-cache
     #   myst-nb
-importlib-resources==5.10.4
-    # via rocm-docs-core
-ipykernel==6.22.0
+ipykernel==6.21.3
     # via myst-nb
 ipython==8.11.0
     # via
@@ -87,7 +85,7 @@ jsonschema==4.17.3
     # via nbformat
 jupyter-cache==0.5.0
     # via myst-nb
-jupyter-client==8.1.0
+jupyter-client==8.0.3
     # via
     #   ipykernel
     #   nbclient
@@ -124,7 +122,7 @@ nbclient==0.5.13
     # via
     #   jupyter-cache
     #   myst-nb
-nbformat==5.8.0
+nbformat==5.7.3
     # via
     #   jupyter-cache
     #   myst-nb
@@ -187,7 +185,7 @@ pyyaml==6.0
     #   myst-parser
     #   pybtex
     #   sphinx-external-toc
-pyzmq==25.0.2
+pyzmq==25.0.1
     # via
     #   ipykernel
     #   jupyter-client
@@ -195,8 +193,8 @@ requests==2.28.2
     # via
     #   pygithub
     #   sphinx
-rocm-docs-core @ git+https://github.com/RadeonOpenCompute/rocm-docs-core.git
-    # via -r requirements.in
+rocm-docs-core==0.2.0
+    # via -r .sphinx/requirements.in
 six==1.16.0
     # via
     #   asttokens
@@ -235,9 +233,7 @@ sphinx-notfound-page==0.8.3
 sphinxcontrib-applehelp==1.0.4
     # via sphinx
 sphinxcontrib-bibtex==2.5.0
-    # via
-    #   -r requirements.in
-    #   rocm-docs-core
+    # via -r .sphinx/requirements.in
 sphinxcontrib-devhelp==1.0.2
     # via sphinx
 sphinxcontrib-htmlhelp==2.0.1
@@ -248,7 +244,7 @@ sphinxcontrib-qthelp==1.0.3
     # via sphinx
 sphinxcontrib-serializinghtml==1.1.5
     # via sphinx
-sqlalchemy==1.4.47
+sqlalchemy==1.4.46
     # via jupyter-cache
 stack-data==0.6.2
     # via ipython
...
@@ -168,6 +168,11 @@
 // flag to enable (1) or disable (0) the debugging output in some kernels
 #define DEBUG_LOG 0

+// denorm test fix, required to work around denorm issue
+#ifndef CK_WORKAROUND_DENORM_FIX
+#define CK_WORKAROUND_DENORM_FIX 0
+#endif
+
 namespace ck {

 enum struct InMemoryDataOperationEnum
...
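Because the guard uses #ifndef, the workaround defaults to off but can be overridden per build, e.g. by passing the standard predefine -DCK_WORKAROUND_DENORM_FIX=1 on the compiler command line (any prior definition wins over the 0 default; the flag name itself comes from the hunk above).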
@@ -505,6 +505,15 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
     }

     // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
+    constexpr long_index_t TwoGB = (long_index_t{1} << 31);
+
+    if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
+         b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
+         e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
+    {
+        return false;
+    }
+
     return true;
 }
...
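The new guard rejects problem sizes whose A, B, or E buffers exceed 2^31 bytes (2 GiB), the limit of 32-bit signed byte offsets. A minimal standalone sketch of the same check, assuming only a descriptor type exposing GetElementSpaceSize() as in the diff (names here are illustrative):

#include <cstdint>

using long_index_t = int64_t;

// Any buffer addressed with 32-bit signed byte offsets must stay within
// 2^31 bytes; compute the size in 64-bit arithmetic to avoid overflow.
template <typename DataType, typename GridDesc>
bool BufferFitsIn2GB(const GridDesc& desc)
{
    constexpr long_index_t TwoGB = long_index_t{1} << 31;
    const long_index_t bytes =
        static_cast<long_index_t>(desc.GetElementSpaceSize()) * sizeof(DataType);
    return bytes <= TwoGB;
}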
@@ -96,7 +96,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
 // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
 // when mfma is fixed, remove this section and update
 // ABDataTypeAdjusted -> ABDataType throughout this file
-#if defined(__gfx90a__)
+#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
 using ABDataTypeAdjusted =
     conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
 #else
...
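The comments describe the workaround's data path: fp16 inputs are widened to fp32 and then narrowed to bf16 so the bf16 MFMA path runs instead of the fp16 one that mishandles denormals. A scalar sketch of the narrowing step (shown as truncation; CK's actual type_convert may round-to-nearest instead):

#include <cstdint>
#include <cstring>

// Truncating fp32 -> bf16: keep the top 16 bits (sign, 8 exponent bits,
// 7 mantissa bits). bf16 shares fp32's exponent range, so fp16 denormals
// become ordinary normal values once widened to fp32 and narrowed to bf16.
uint16_t float_to_bf16_trunc(float x)
{
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}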
@@ -264,6 +264,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
     }

     // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
+    constexpr long_index_t TwoGB = (long_index_t{1} << 31);
+
+    if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB &&
+         b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB))
+    {
+        return false;
+    }
+
     return true;
 }
...
@@ -265,7 +265,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
 // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
 // when mfma is fixed, remove this section and update
 // FloatABAdjusted -> FloatAB throughout this file
-#if defined(__gfx90a__)
+#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
 using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
 #else
 using FloatABAdjusted = FloatAB;
...
@@ -135,7 +135,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
 // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
 // when mfma is fixed, remove this section and update
 // FloatABAdjusted -> FloatAB throughout this file
-#if defined(__gfx90a__)
+#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
 using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
 #else
 using FloatABAdjusted = FloatAB;
...
@@ -25,6 +25,10 @@ void add_device_normalization_rank_5_3_swish_f16_instances(
 void add_device_normalization_rank_5_3_swish_f32_instances(
     std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>&);

+// [x, gamma, beta, y] = [f16, f32, f32, f16]
+void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
+    std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>>&);
+
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
...
@@ -70,6 +74,14 @@ struct DeviceOperationInstanceFactory<
                 add_device_normalization_rank_5_3_swish_f32_instances(op_ptrs);
             }
         }
+        else if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F32> &&
+                          is_same_v<BetaDataType, F32> && is_same_v<YDataType, F16>)
+        {
+            if constexpr(Rank == 5 && NumReduceDim == 3)
+            {
+                add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(op_ptrs);
+            }
+        }

         return op_ptrs;
     }
...
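For reference, a hedged sketch of how client code reaches the new branch through the factory; the GetInstances() entry point and alias spellings follow CK's usual instance-factory pattern and the surrounding diff, but treat them as assumptions:

#include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp"

using F16   = ck::half_t;
using F32   = float;
using Swish = ck::tensor_operation::element_wise::Swish;

// Requesting x=f16, gamma=f32, beta=f32, compute=f32, y=f16 with rank 5 and
// 3 reduced dimensions dispatches to the new else-if branch added above.
auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>::
    GetInstances();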
@@ -7,4 +7,5 @@ add_instance_library(device_normalization_instance
     device_groupnorm_f32_instance.cpp
     device_groupnorm_swish_f16_instance.cpp
     device_groupnorm_swish_f32_instance.cpp
+    device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
 )
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using Swish = ck::tensor_operation::element_wise::Swish;
void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>>&
instances)
{
add_device_operation_instances(instances,
device_normalization_f16_f32_f32_f16_instances<Swish, 5, 3>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -69,6 +69,32 @@ using device_normalization_f32_instances = std::tuple<
     // clang-format on
     >;

+template <typename OutElementwise, index_t Rank, index_t Reduce>
+using device_normalization_f16_f32_f32_f16_instances = std::tuple<
+    // clang-format off
+    //                      XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
+    // clang-format on
+    >;
+
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
...
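Reading one row against the column header above: the BlockSize=128, KThreadClusterSize=128, KThreadSliceSize=4 instance covers roughly 128 x 4 = 512 elements of the reduced dimensions per pass, and its vector size of 4 requires the innermost lengths to be divisible by 4. A hypothetical alias for that single configuration, with the column meanings taken from the header comment (an assumption, since the header lists fewer names than the row has arguments):

// One concrete instance from the table: x=f16, gamma/beta=f32, compute=f32,
// y=f16, Swish, rank 5 reducing 3 dims; 128 threads per block arranged as a
// 1 x 128 (M x K) cluster, K-slice 4, 4-wide vector access on x/gamma/beta/y.
using ExampleNormInstance =
    DeviceNormalizationImpl<F16, F32, F32, F32, F16, Swish, 5, 3,
                            128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>;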