Commit a715222c authored by yuguo's avatar yuguo
Browse files

0.9.1-rocm

parent f262efc9
......@@ -38,56 +38,13 @@ limitations under the License.
#include "oneflow/core/common/hash_container.h"
#include "oneflow/core/common/meta_util.hpp"
#include "oneflow/core/common/singleton.h"
#include "oneflow/core/common/hash.h"
#include "oneflow/core/common/cpp_attribute.h"
#define CHECK_ISNULL(e) CHECK((e) == nullptr)
namespace oneflow {
inline size_t HashCombine(size_t lhs, size_t rhs) {
return lhs ^ (rhs + 0x9e3779b9 + (lhs << 6U) + (lhs >> 2U));
}
inline void HashCombine(size_t* seed, size_t hash) { *seed = HashCombine(*seed, hash); }
template<typename... T>
inline void AddHash(size_t* seed, const T&... v) {
__attribute__((__unused__)) int dummy[] = {(HashCombine(seed, std::hash<T>()(v)), 0)...};
}
template<typename T, typename... Ts>
inline size_t Hash(const T& v1, const Ts&... vn) {
size_t seed = std::hash<T>()(v1);
AddHash<Ts...>(&seed, vn...);
return seed;
}
} // namespace oneflow
namespace std {
template<typename T0, typename T1>
struct hash<std::pair<T0, T1>> {
std::size_t operator()(const std::pair<T0, T1>& p) const {
return oneflow::Hash<T0, T1>(p.first, p.second);
}
};
template<typename T>
struct hash<std::vector<T>> {
std::size_t operator()(const std::vector<T>& vec) const {
std::size_t hash_value = vec.size();
for (const auto& elem : vec) { oneflow::AddHash<T>(&hash_value, elem); }
return hash_value;
}
};
} // namespace std
namespace oneflow {
#define OF_DISALLOW_COPY(ClassName) \
ClassName(const ClassName&) = delete; \
ClassName& operator=(const ClassName&) = delete
......
......@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <bitset>
#include "oneflow/core/common/maybe.h"
namespace oneflow {
......@@ -37,4 +38,23 @@ static inline Maybe<int64_t> maybe_wrap_dim(int64_t dim, int64_t dim_post_expr,
if (dim < 0) dim += dim_post_expr;
return dim;
}
// align with pytorch: `aten/src/ATen/WrapDimUtilsMulti.h`
constexpr size_t dim_bitset_size = 64;
static inline Maybe<std::bitset<dim_bitset_size>> dim_list_to_bitset(
const std::vector<int32_t>& dims, int64_t ndims) {
CHECK_LE_OR_RETURN(ndims, (int64_t)dim_bitset_size)
<< Error::RuntimeError() << "Only tensors with up to " << dim_bitset_size
<< " dims are supported";
std::bitset<dim_bitset_size> seen;
for (int32_t i = 0; i < dims.size(); i++) {
size_t dim = JUST(maybe_wrap_dim(dims[i], ndims));
CHECK_OR_RETURN_ERROR(!seen[dim]) << Error::RuntimeError() << "The dim " << dim
<< " appears multiple times in the list of dims";
seen[dim] = true;
}
return seen;
}
} // namespace oneflow
......@@ -13,8 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/control/rank_info_bootstrap_server.h"
#include <thread>
#include <mutex>
#include <chrono>
#include "grpc/grpc_posix.h"
#include "oneflow/core/common/env_var/bootstrap.h"
#include "oneflow/core/control/rank_info_bootstrap_server.h"
namespace oneflow {
......@@ -29,12 +33,25 @@ std::string GetHostFromUri(const std::string& uri) {
return uri.substr(first_delimiter_pos + 1, second_delimiter_pos - first_delimiter_pos - 1);
}
int64_t rpc_bootstrap_server_sleep_seconds() {
static const int64_t rpc_bootstrap_server_sleep_seconds =
EnvInteger<ONEFLOW_RPC_BOOTSTRAP_SERVER_SLEEP_SECONDS>();
return rpc_bootstrap_server_sleep_seconds;
}
int64_t rpc_bootstrap_server_max_retry_times() {
static const int64_t rpc_bootstrap_server_max_retry_times =
EnvInteger<ONEFLOW_RPC_BOOTSTRAP_SERVER_MAX_RETRY_TIMES>();
return rpc_bootstrap_server_max_retry_times;
}
} // namespace
RankInfoBootstrapServer::RankInfoBootstrapServer(const BootstrapConf& bootstrap_conf)
: BootstrapServer(), port_(0), world_size_(bootstrap_conf.world_size()) {
Init();
int p = (bootstrap_conf.rank() == 0 ? bootstrap_conf.master_addr().port() : 0);
const int64_t rank = bootstrap_conf.rank();
int p = (rank == 0 ? bootstrap_conf.master_addr().port() : 0);
grpc::ServerBuilder server_builder;
server_builder.SetMaxMessageSize(INT_MAX);
server_builder.AddListeningPort("0.0.0.0:" + std::to_string(p), grpc::InsecureServerCredentials(),
......@@ -43,10 +60,59 @@ RankInfoBootstrapServer::RankInfoBootstrapServer(const BootstrapConf& bootstrap_
server_builder.RegisterService(grpc_service_.get());
cq_ = server_builder.AddCompletionQueue();
grpc_server_ = server_builder.BuildAndStart();
if (bootstrap_conf.rank() == 0) { CHECK_EQ(p, port()) << "Port " << p << " is unavailable"; }
if (rank == 0) { CHECK_EQ(p, port()) << "Port " << p << " is unavailable"; }
LOG(INFO) << "RankInfoBootstrapServer listening on "
<< "0.0.0.0:" + std::to_string(port());
loop_thread_ = std::thread(&RankInfoBootstrapServer::HandleRpcs, this);
if (rank == 0) {
rank2host_ = std::make_shared<std::vector<std::string>>(world_size_, "");
// NOTE: use check_thread_ to check RankInfoBootstrapServer status on rank 0
// if size of ready ranks == total ranks(world_size), means status is ok.
// otherwise, it indicates that other ranks' server have not been created successfully!
check_thread_ = std::thread(&RankInfoBootstrapServer::CheckServerStatus, this);
}
}
void RankInfoBootstrapServer::CheckServerStatus() {
bool status_ok = false;
int64_t skip_warning_times = 1;
int64_t retry_idx = 0;
// lambda function to get valid rank num of rank2host_
auto GetValidRank2HostSize = [](const std::shared_ptr<std::vector<std::string>>& rank2host) {
int64_t valid_size = 0;
for (int64_t i = 0; i < rank2host->size(); ++i) {
if (rank2host->at(i) == "") { continue; }
valid_size += 1;
}
return valid_size;
};
for (; retry_idx < rpc_bootstrap_server_max_retry_times(); ++retry_idx) {
std::this_thread::sleep_for(std::chrono::seconds(rpc_bootstrap_server_sleep_seconds()));
int64_t valid_size = 0;
{
std::lock_guard<std::mutex> lock(lock_);
valid_size = GetValidRank2HostSize(rank2host_);
}
CHECK(valid_size <= world_size_);
if (valid_size == world_size_) {
status_ok = true;
break;
} else {
if (retry_idx >= skip_warning_times) {
LOG(WARNING) << "BootstrapServer not ready, rpc server on some rank have not been created "
"successfully. Failed at "
<< retry_idx + 1 << " times, total ranks(world_size): " << world_size_
<< ", ready ranks: " << valid_size;
}
}
}
if (!status_ok) {
LOG(FATAL) << "CheckServerStatus() failed, rpc server on some rank are not ready, please check "
"whether the processes on all ranks are "
"created successfully.";
}
}
Maybe<const std::vector<std::string>&> RankInfoBootstrapServer::rank2host() const {
......@@ -59,6 +125,7 @@ void RankInfoBootstrapServer::OnLoadServer(CtrlCall<CtrlMethod::kLoadServer>* ca
CHECK_GE(rank, 0);
CHECK_LT(rank, world_size_);
if (!rank2host_) { rank2host_ = std::make_shared<std::vector<std::string>>(world_size_); }
std::lock_guard<std::mutex> lock(lock_);
rank2host_->at(rank) = GetHostFromUri(call->server_ctx().peer());
call->SendResponse();
EnqueueRequest<CtrlMethod::kLoadServer>();
......
......@@ -26,7 +26,9 @@ namespace oneflow {
class RankInfoBootstrapServer final : public BootstrapServer {
public:
OF_DISALLOW_COPY_AND_MOVE(RankInfoBootstrapServer);
~RankInfoBootstrapServer() override = default;
~RankInfoBootstrapServer() override {
if (check_thread_.joinable()) { check_thread_.join(); }
}
RankInfoBootstrapServer(const BootstrapConf& bootstrap_conf);
......@@ -35,9 +37,12 @@ class RankInfoBootstrapServer final : public BootstrapServer {
private:
void OnLoadServer(CtrlCall<CtrlMethod::kLoadServer>* call) override;
void CheckServerStatus();
int port_;
const int64_t world_size_;
std::mutex lock_;
std::thread check_thread_;
// use std::shared_ptr as std::optional
std::shared_ptr<std::vector<std::string>> rank2host_;
};
......
......@@ -16,13 +16,22 @@ limitations under the License.
#include "oneflow/core/control/rpc_client.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/core/common/env_var/bootstrap.h"
namespace oneflow {
namespace {
const int32_t max_retry_num = 60;
const int64_t sleep_seconds = 10;
int64_t rpc_client_max_retry_times() {
static const int64_t rpc_client_max_retry_times =
EnvInteger<ONEFLOW_RPC_CLIENT_MAX_RETRY_TIMES>();
return rpc_client_max_retry_times;
}
int64_t rpc_client_sleep_seconds() {
static const int64_t rpc_client_sleep_seconds = EnvInteger<ONEFLOW_RPC_CLIENT_SLEEP_SECONDS>();
return rpc_client_sleep_seconds;
}
#define GRPC_CHECK(x) CHECK_EQ(x.error_code(), grpc::StatusCode::OK)
......@@ -179,23 +188,28 @@ void RpcClient::LoadServer(const std::string& server_addr, CtrlService::Stub* st
void RpcClient::LoadServer(const LoadServerRequest& request, CtrlService::Stub* stub) {
int32_t retry_idx = 0;
for (; retry_idx < max_retry_num; ++retry_idx) {
int32_t skip_warning_times = 3;
for (; retry_idx < rpc_client_max_retry_times(); ++retry_idx) {
grpc::ClientContext client_ctx;
LoadServerResponse response;
grpc::Status st = stub->CallMethod<CtrlMethod::kLoadServer>(&client_ctx, request, &response);
if (st.error_code() == grpc::StatusCode::OK) {
VLOG(3) << "LoadServer " << request.addr() << " Successful at " << retry_idx << " times";
VLOG(3) << "LoadServer " << request.addr() << " Successful at " << retry_idx + 1 << " times";
break;
} else if (st.error_code() == grpc::StatusCode::UNAVAILABLE) {
LOG(WARNING) << "LoadServer " << request.addr() << " Failed at " << retry_idx << " times"
<< " error_code " << st.error_code() << " error_message " << st.error_message();
std::this_thread::sleep_for(std::chrono::seconds(sleep_seconds));
if (retry_idx >= skip_warning_times) {
LOG(WARNING) << "LoadServer " << request.addr() << " Failed at " << retry_idx + 1
<< " times"
<< " error_code: " << st.error_code()
<< " error_message: " << st.error_message();
}
std::this_thread::sleep_for(std::chrono::seconds(rpc_client_sleep_seconds()));
continue;
} else {
LOG(FATAL) << st.error_message();
}
}
CHECK_LT(retry_idx, max_retry_num);
CHECK_LT(retry_idx, rpc_client_max_retry_times());
}
CtrlService::Stub* RpcClient::GetThisStub() { return stubs_[GlobalProcessCtx::Rank()].get(); }
......
......@@ -16,7 +16,14 @@ limitations under the License.
#ifndef ONEFLOW_CORE_CUDA_ATOMIC_H_
#define ONEFLOW_CORE_CUDA_ATOMIC_H_
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIPCC__)
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#else
#include <cuda.h>
#include <cuda_runtime.h>
......@@ -25,6 +32,9 @@ limitations under the License.
#if CUDA_VERSION >= 11000
#include <cuda_bf16.h>
#endif // CUDA_VERSION >= 11000
#endif
namespace oneflow {
namespace cuda {
......@@ -34,58 +44,90 @@ namespace atomic {
namespace internal {
template<typename T, typename U>
__device__ __forceinline__ T CastCASImpl(T* address, T compare, T val) {
static_assert(sizeof(T) == sizeof(U), "");
U ret = atomicCAS(reinterpret_cast<U*>(address), *(reinterpret_cast<U*>(&compare)),
*(reinterpret_cast<U*>(&val)));
return *(reinterpret_cast<T*>(&ret));
}
struct CastCASImpl {
__device__ __forceinline__ T operator()(T* address, T compare, T val, bool* success) const {
static_assert(sizeof(T) == sizeof(U), "");
U assumed = *(reinterpret_cast<U*>(&compare));
U ret = atomicCAS(reinterpret_cast<U*>(address), assumed, *(reinterpret_cast<U*>(&val)));
*success = (ret == assumed);
return *(reinterpret_cast<T*>(&ret));
}
};
#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__))
template<typename T>
struct CastCASImpl<T, unsigned short int> {
__device__ __forceinline__ T operator()(T* address, T compare, T val, bool* success) const {
static_assert(sizeof(T) == sizeof(unsigned short int), "");
size_t offset = reinterpret_cast<size_t>(address) & 0x2;
unsigned int* address_as_ui =
reinterpret_cast<unsigned int*>(reinterpret_cast<char*>(address) - offset);
unsigned int old = *address_as_ui;
unsigned int assumed = *(reinterpret_cast<unsigned short int*>(&compare));
unsigned int newval = *(reinterpret_cast<unsigned short int*>(&val));
assumed = offset ? (old & 0xffff) | (assumed << 16) : (old & 0xffff0000) | assumed;
newval = offset ? (old & 0xffff) | (newval << 16) : (old & 0xffff0000) | newval;
unsigned int ret = atomicCAS(address_as_ui, assumed, newval);
*success = (ret == assumed);
ret = offset ? (ret >> 16) : (ret & 0xffff);
return *(reinterpret_cast<T*>(&ret));
}
};
#endif // __CUDA_ARCH__
template<typename T>
__device__ __forceinline__ typename std::enable_if<sizeof(T) == sizeof(unsigned int), T>::type
CASImpl(T* address, T compare, T val) {
return CastCASImpl<T, unsigned int>(address, compare, val);
CASImpl(T* address, T compare, T val, bool* success) {
return CastCASImpl<T, unsigned int>()(address, compare, val, success);
}
template<typename T>
__device__ __forceinline__
typename std::enable_if<sizeof(T) == sizeof(unsigned long long int), T>::type
CASImpl(T* address, T compare, T val) {
return CastCASImpl<T, unsigned long long int>(address, compare, val);
CASImpl(T* address, T compare, T val, bool* success) {
return CastCASImpl<T, unsigned long long int>()(address, compare, val, success);
}
template<typename T>
__device__ __forceinline__ typename std::enable_if<sizeof(T) == sizeof(unsigned short int), T>::type
CASImpl(T* address, T compare, T val) {
#if __CUDA_ARCH__ >= 700
return CastCASImpl<T, unsigned short int>(address, compare, val);
#else
__trap();
return 0;
#endif // __CUDA_ARCH__ >= 700
CASImpl(T* address, T compare, T val, bool* success) {
return CastCASImpl<T, unsigned short int>()(address, compare, val, success);
}
__device__ __forceinline__ int CASImpl(int* address, int compare, int val) {
return atomicCAS(address, compare, val);
__device__ __forceinline__ int CASImpl(int* address, int compare, int val, bool* success) {
int ret = atomicCAS(address, compare, val);
*success = (ret == compare);
return ret;
}
__device__ __forceinline__ unsigned int CASImpl(unsigned int* address, unsigned int compare,
unsigned int val) {
return atomicCAS(address, compare, val);
unsigned int val, bool* success) {
unsigned int ret = atomicCAS(address, compare, val);
*success = (ret == compare);
return ret;
}
__device__ __forceinline__ unsigned long long int CASImpl(unsigned long long int* address,
unsigned long long int compare,
unsigned long long int val) {
return atomicCAS(address, compare, val);
unsigned long long int val,
bool* success) {
unsigned long long int ret = atomicCAS(address, compare, val);
*success = (ret == compare);
return ret;
}
#if __CUDA_ARCH__ >= 700
__device__ __forceinline__ unsigned short int CASImpl(unsigned short int* address,
unsigned short int compare,
unsigned short int val) {
return atomicCAS(address, compare, val);
unsigned short int val, bool* success) {
unsigned short int ret = atomicCAS(address, compare, val);
*success = (ret == compare);
return ret;
}
#endif // __CUDA_ARCH__ >= 700
......@@ -99,10 +141,11 @@ template<typename T, template<typename> class BinaryOp>
__device__ __forceinline__ T AtomicCASBinaryImpl(T* address, T val) {
T old = *address;
T assumed;
bool success = false;
do {
assumed = old;
old = CASImpl(address, assumed, BinaryOp<T>()(old, val));
} while (old != assumed);
old = CASImpl(address, assumed, BinaryOp<T>()(old, val), &success);
} while (!success);
return old;
}
......@@ -156,17 +199,41 @@ __device__ __forceinline__ nv_bfloat16 AddImpl(nv_bfloat16* address, nv_bfloat16
return atomicAdd(address, val);
}
__device__ __forceinline__ nv_bfloat162 AddImpl(nv_bfloat162* address, nv_bfloat162 val) {
return atomicAdd(address, val);
}
#endif // __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ < 530
#if (__CUDA_ARCH__ < 530) && !defined(WITH_ROCM)
__device__ __forceinline__ half2 AddImpl(half2* address, half2 val) {
__trap();
TRAP();
return val;
}
#endif // __CUDA_ARCH__ < 530
#ifdef WITH_ROCM
__device__ __forceinline__ double AddImpl(double* address, double val) {
return atomicAdd(address, val);
}
__device__ __forceinline__ half AddImpl(half* address, half val) {
float address_value = __half2float(*address);
return __float2half(atomicAdd(&address_value, __half2float(val))); }
__device__ __forceinline__ half2 AddImpl(half2* address, half2 val) {
half2 res;
float2 address_value = __half22float2(*address);
res.data.x = __float2half(atomicAdd(&address_value.x, __half2float(val.data.x)));
res.data.y = __float2half(atomicAdd(&address_value.y, __half2float(val.data.y)));
return res;
}
#endif
} // namespace internal
template<typename T, typename U>
......@@ -181,7 +248,8 @@ __device__ __forceinline__ typename std::enable_if<std::is_same<T, U>::value, T>
template<typename T, typename U, typename V>
__device__ __forceinline__ T CAS(T* address, U compare, V val) {
return internal::CASImpl(address, Cast<T>(compare), Cast<T>(val));
bool success = false;
return internal::CASImpl(address, Cast<T>(compare), Cast<T>(val), &success);
}
template<typename T, typename U>
......@@ -189,6 +257,56 @@ __device__ __forceinline__ T Add(T* address, U val) {
return internal::AddImpl(address, Cast<T>(val));
}
__device__ __forceinline__ float Mul(int32_t* address, const int32_t val) {
int32_t old = *address, assumed;
do {
assumed = old;
old = atomicCAS(address, assumed, val * assumed);
} while (assumed != old);
return old;
}
__device__ __forceinline__ float Mul(uint32_t* address, const uint32_t val) {
uint32_t old = *address, assumed;
do {
assumed = old;
old = atomicCAS(address, assumed, val * assumed);
} while (assumed != old);
return old;
}
__device__ __forceinline__ float Mul(uint64_t* address, const uint64_t val) {
static_assert(sizeof(uint64_t) == sizeof(unsigned long long int), "");
unsigned long long int old = *reinterpret_cast<unsigned long long int*>(address), assumed;
do {
assumed = old;
old = atomicCAS(reinterpret_cast<unsigned long long int*>(address), assumed,
static_cast<unsigned long long int>(val) * assumed);
} while (assumed != old);
return old;
}
__device__ __forceinline__ float Mul(float* address, const float val) {
int32_t* address_as_int = reinterpret_cast<int32_t*>(address);
int32_t old = *address_as_int, assumed;
do {
assumed = old;
old = atomicCAS(address_as_int, assumed, __float_as_int(val * __int_as_float(assumed)));
} while (assumed != old);
return __int_as_float(old);
}
__device__ __forceinline__ float Mul(double* address, const double val) {
unsigned long long int* address_as_ull = reinterpret_cast<unsigned long long int*>(address);
unsigned long long int old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val * __longlong_as_double(assumed)));
} while (assumed != old);
return __longlong_as_double(old);
}
__device__ __forceinline__ float Max(float* address, const float val) {
int* address_as_i = (int*)address;
int old = *address_as_i;
......
......@@ -16,7 +16,13 @@ limitations under the License.
#ifndef ONEFLOW_CORE_CUDA_ELEMENTWISE_H_
#define ONEFLOW_CORE_CUDA_ELEMENTWISE_H_
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif
#include "oneflow/core/ep/include/gpu_macro.h"
#include <cstdint>
#include <algorithm>
#include <type_traits>
......@@ -30,25 +36,25 @@ namespace elementwise {
constexpr int kBlockSize = 256;
constexpr int kNumWaves = 32;
inline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) {
inline GPU(Error_t) GetNumBlocks(int64_t n, int* num_blocks) {
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err = GPU(GetDevice)(&dev);
if (err != GPU(Success)) { return err; }
}
int sm_count;
{
cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err = GPU(DeviceGetAttribute)(&sm_count, GPUMultiProcessorCount, dev);
if (err != GPU(Success)) { return err; }
}
int tpm;
{
cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err = GPU(DeviceGetAttribute)(&tpm, GPUMaxThreadsPerMultiProcessor, dev);
if (err != GPU(Success)) { return err; }
}
*num_blocks = std::max<int>(1, std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
sm_count * tpm / kBlockSize * kNumWaves));
return cudaSuccess;
return GPU(Success);
}
template<typename T, int pack_size>
......@@ -113,24 +119,24 @@ class HasApply2 {
template<int pack_size, typename FunctorT, typename R, typename... IN>
__device__ typename std::enable_if<HasApply2<FunctorT>::value == true && pack_size % 2 == 0,
Packed<R, pack_size>>::type
ApplyPack(const FunctorT& functor, const IN... in[pack_size]) {
ApplyPack(const FunctorT& functor, const Packed<IN, pack_size>... in) {
Packed<R, pack_size> ret;
#pragma unroll
for (int j = 0; j < pack_size; j += 2) { functor.Apply2(ret.elem + j, (in + j)...); }
for (int j = 0; j < pack_size; j += 2) { functor.Apply2(ret.elem + j, (in.elem + j)...); }
return ret;
}
template<int pack_size, typename FunctorT, typename R, typename... IN>
__device__ typename std::enable_if<HasApply2<FunctorT>::value == false || pack_size % 2 != 0,
Packed<R, pack_size>>::type
ApplyPack(const FunctorT& functor, const IN... in[pack_size]) {
ApplyPack(const FunctorT& functor, const Packed<IN, pack_size>... in) {
Packed<R, pack_size> ret;
#pragma unroll
for (int j = 0; j < pack_size; ++j) { ret.elem[j] = functor((in[j])...); }
for (int j = 0; j < pack_size; ++j) { ret.elem[j] = functor((in.elem[j])...); }
return ret;
}
template<int pack_size, bool tail, typename FactoryT, typename R, typename... IN>
template<int pack_size, typename FactoryT, typename R, typename... IN>
__global__ void __launch_bounds__(kBlockSize)
ApplyGeneric(FactoryT factory, int64_t n_pack, Packed<R, pack_size>* pack_r,
const Packed<IN, pack_size>*... pack_in, int64_t n_tail, R* tail_r,
......@@ -138,9 +144,9 @@ __global__ void __launch_bounds__(kBlockSize)
auto functor = factory();
const int global_tid = blockIdx.x * kBlockSize + threadIdx.x;
for (int64_t i = global_tid; i < n_pack; i += blockDim.x * gridDim.x) {
pack_r[i] = ApplyPack<pack_size, decltype(functor), R, IN...>(functor, (pack_in[i].elem)...);
pack_r[i] = ApplyPack<pack_size, decltype(functor), R, IN...>(functor, (pack_in[i])...);
}
if (tail && global_tid < n_tail) { tail_r[global_tid] = functor((tail_in[global_tid])...); }
if (global_tid < n_tail) { tail_r[global_tid] = functor((tail_in[global_tid])...); }
}
template<typename FunctorT>
......@@ -153,41 +159,39 @@ struct SimpleFactory {
};
template<size_t pack_size>
bool IsAligendForPack() {
bool IsAlignedForPack() {
return true;
}
template<size_t pack_size, typename T, typename... Args>
bool IsAligendForPack(const T* ptr, const Args*... others) {
bool IsAlignedForPack(const T* ptr, const Args*... others) {
return reinterpret_cast<uintptr_t>(ptr) % sizeof(Pack<T, pack_size>) == 0
&& IsAligendForPack<pack_size, Args...>(others...);
&& IsAlignedForPack<pack_size, Args...>(others...);
}
template<size_t pack_size, typename FactoryT, typename R, typename... IN>
cudaError_t LaunchKernel(FactoryT factory, int64_t n, R* r, const IN*... in, cudaStream_t stream) {
GPU(Error_t) LaunchKernel(FactoryT factory, int64_t n, R* r, const IN*... in, GPU(Stream_t) stream) {
const int64_t n_pack = n / pack_size;
const int64_t tail_offset = n_pack * pack_size;
const int64_t n_tail = n - tail_offset;
int num_blocks;
{
cudaError_t err = GetNumBlocks(n_pack, &num_blocks);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err = GetNumBlocks(n_pack, &num_blocks);
if (err != GPU(Success)) { return err; }
}
auto func = n_tail > 0 ? ApplyGeneric<pack_size, true, FactoryT, R, IN...>
: ApplyGeneric<pack_size, false, FactoryT, R, IN...>;
func<<<num_blocks, kBlockSize, 0, stream>>>(
ApplyGeneric<pack_size, FactoryT, R, IN...><<<num_blocks, kBlockSize, 0, stream>>>(
factory, n_pack, reinterpret_cast<Packed<R, pack_size>*>(r),
(reinterpret_cast<const Packed<IN, pack_size>*>(in))..., n_tail, r + tail_offset,
(in + tail_offset)...);
return cudaPeekAtLastError();
return GPU(PeekAtLastError)();
}
template<typename FactoryT, typename R, typename... IN>
struct GenericLauncher {
static cudaError_t Launch(FactoryT factory, int64_t n, R* r, const IN*... in,
cudaStream_t stream) {
static GPU(Error_t) Launch(FactoryT factory, int64_t n, R* r, const IN*... in,
GPU(Stream_t) stream) {
constexpr int max_pack_size = PackSize<R, IN...>();
if (IsAligendForPack<max_pack_size, R, IN...>(r, in...)) {
if (IsAlignedForPack<max_pack_size, R, IN...>(r, in...)) {
return LaunchKernel<max_pack_size, FactoryT, R, IN...>(factory, n, r, in..., stream);
} else {
return LaunchKernel<1, FactoryT, R, IN...>(factory, n, r, in..., stream);
......@@ -196,37 +200,37 @@ struct GenericLauncher {
};
template<typename FactoryT, typename R, typename A>
inline cudaError_t UnaryWithFactory(FactoryT factory, int64_t n, R* r, const A* a,
cudaStream_t stream) {
inline GPU(Error_t) UnaryWithFactory(FactoryT factory, int64_t n, R* r, const A* a,
GPU(Stream_t) stream) {
return GenericLauncher<FactoryT, R, A>::Launch(factory, n, r, a, stream);
}
template<typename FunctorT, typename R, typename A>
inline cudaError_t Unary(FunctorT functor, int64_t n, R* r, const A* a, cudaStream_t stream) {
inline GPU(Error_t) Unary(FunctorT functor, int64_t n, R* r, const A* a, GPU(Stream_t) stream) {
return UnaryWithFactory(SimpleFactory<FunctorT>(functor), n, r, a, stream);
}
template<typename FactoryT, typename R, typename A, typename B>
inline cudaError_t BinaryWithFactory(FactoryT factory, int64_t n, R* r, const A* a, const B* b,
cudaStream_t stream) {
inline GPU(Error_t) BinaryWithFactory(FactoryT factory, int64_t n, R* r, const A* a, const B* b,
GPU(Stream_t) stream) {
return GenericLauncher<FactoryT, R, A, B>::Launch(factory, n, r, a, b, stream);
}
template<typename FunctorT, typename R, typename A, typename B>
inline cudaError_t Binary(FunctorT functor, int64_t n, R* r, const A* a, const B* b,
cudaStream_t stream) {
inline GPU(Error_t) Binary(FunctorT functor, int64_t n, R* r, const A* a, const B* b,
GPU(Stream_t) stream) {
return BinaryWithFactory(SimpleFactory<FunctorT>(functor), n, r, a, b, stream);
}
template<typename FactoryT, typename R, typename A, typename B, typename C>
inline cudaError_t TernaryWithFactory(FactoryT factory, int64_t n, R* r, const A* a, const B* b,
const C* c, cudaStream_t stream) {
inline GPU(Error_t) TernaryWithFactory(FactoryT factory, int64_t n, R* r, const A* a, const B* b,
const C* c, GPU(Stream_t) stream) {
return GenericLauncher<FactoryT, R, A, B, C>::Launch(factory, n, r, a, b, c, stream);
}
template<typename FunctorT, typename R, typename A, typename B, typename C>
inline cudaError_t Ternary(FunctorT functor, int64_t n, R* r, const A* a, const B* b, const C* c,
cudaStream_t stream) {
inline GPU(Error_t) Ternary(FunctorT functor, int64_t n, R* r, const A* a, const B* b, const C* c,
GPU(Stream_t) stream) {
return TernaryWithFactory(SimpleFactory<FunctorT>(functor), n, r, a, b, c, stream);
}
......
......@@ -17,8 +17,14 @@ limitations under the License.
#ifndef ONEFLOW_CORE_CUDA_LAYER_NORM_H_
#define ONEFLOW_CORE_CUDA_LAYER_NORM_H_
#ifdef WITH_ROCM
#include "hip/hip_runtime.h"
#include <hipcub/hipcub.hpp>
#else
#include <cub/cub.cuh>
#include <math_constants.h>
#endif
#include <assert.h>
namespace oneflow {
......@@ -27,7 +33,11 @@ namespace cuda {
namespace layer_norm {
#ifdef WITH_ROCM
constexpr int kWarpSize = 64;
#else
constexpr int kWarpSize = 32;
#endif
template<typename T>
struct SumOp {
......@@ -42,14 +52,22 @@ struct MaxOp {
template<template<typename> class ReductionOp, typename T, int thread_group_width = kWarpSize>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
#ifdef WITH_ROCM
val = ReductionOp<T>()(val, __shfl_xor(val, mask, thread_group_width));
#else
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask, thread_group_width));
#endif
}
return val;
}
template<template<typename> class ReductionOp, typename T, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
#ifdef WITH_ROCM
typedef hipcub::BlockReduce<T, block_size> BlockReduce;
#else
typedef cub::BlockReduce<T, block_size> BlockReduce;
#endif
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
......@@ -93,26 +111,26 @@ __inline__ __device__ double Rsqrt<double>(double x) {
}
template<class Func>
inline cudaError_t GetNumBlocks(Func func, int64_t block_size, size_t dynamic_smem_size,
inline GPU(Error_t) GetNumBlocks(Func func, int64_t block_size, size_t dynamic_smem_size,
int64_t max_blocks, int64_t waves, int* num_blocks) {
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err = GPU(GetDevice)(&dev);
if (err != GPU(Success)) { return err; }
}
int sm_count;
{
cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err = GPU(DeviceGetAttribute)(&sm_count, GPUMultiProcessorCount, dev);
if (err != GPU(Success)) { return err; }
}
int max_active_blocks;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, func,
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(&max_active_blocks, func,
block_size, dynamic_smem_size);
}
*num_blocks =
std::max<int>(1, std::min<int64_t>(max_blocks, sm_count * max_active_blocks * waves));
return cudaSuccess;
return GPU(Success);
}
template<typename T>
......@@ -132,6 +150,34 @@ struct DefaultComputeType<nv_bfloat16> {
};
#endif // CUDA_VERSION >= 11000
template<typename T>
class HasCanPackAs {
typedef char one;
struct two {
char x[2];
};
template<typename C>
static one test(decltype(&C::CanPackAs));
template<typename C>
static two test(...);
public:
enum { value = sizeof(test<T>(0)) == sizeof(char) };
};
template<typename T>
typename std::enable_if<HasCanPackAs<T>::value == true, bool>::type CanPackAs(T t,
size_t pack_size) {
return t.CanPackAs(pack_size);
}
template<typename T>
typename std::enable_if<HasCanPackAs<T>::value == false, bool>::type CanPackAs(T t,
size_t pack_size) {
return true;
}
template<typename T, int N>
struct GetPackType {
using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;
......@@ -152,6 +198,7 @@ union Pack {
template<typename SRC, typename DST>
struct DirectLoad {
using LoadType = DST;
DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {}
template<int N>
__device__ void load(DST* dst, int64_t row, int64_t col) const {
......@@ -210,9 +257,15 @@ __inline__ __device__ void WelfordWarpReduce(T thread_mean, T thread_m2, T threa
*m2 = thread_m2;
*count = thread_count;
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
#ifdef WITH_ROCM
T b_mean = __shfl_down(*mean, mask, thread_group_width);
T b_m2 = __shfl_down(*m2, mask, thread_group_width);
T b_count = __shfl_down(*count, mask, thread_group_width);
#else
T b_mean = __shfl_down_sync(0xffffffff, *mean, mask, thread_group_width);
T b_m2 = __shfl_down_sync(0xffffffff, *m2, mask, thread_group_width);
T b_count = __shfl_down_sync(0xffffffff, *count, mask, thread_group_width);
#endif
WelfordCombine(b_mean, b_m2, b_count, mean, m2, count);
}
}
......@@ -221,9 +274,16 @@ template<typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(T thread_mean, T thread_m2, T thread_count, T* mean,
T* m2, T* count) {
WelfordWarpReduce<T, thread_group_width>(thread_mean, thread_m2, thread_count, mean, m2, count);
#ifdef WITH_ROCM
*mean = __shfl(*mean, 0, thread_group_width);
*m2 = __shfl(*m2, 0, thread_group_width);
*count = __shfl(*count, 0, thread_group_width);
#else
*mean = __shfl_sync(0xffffffff, *mean, 0, thread_group_width);
*m2 = __shfl_sync(0xffffffff, *m2, 0, thread_group_width);
*count = __shfl_sync(0xffffffff, *count, 0, thread_group_width);
#endif
}
template<typename T>
......@@ -258,7 +318,11 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
warp_m2 = static_cast<T>(0);
warp_count = static_cast<T>(0);
}
__syncwarp();
#ifdef WITH_ROCM
__syncthreads();
#else
__syncwarp();
#endif
T block_mean = 0;
T block_m2 = 0;
T block_count = 0;
......@@ -275,17 +339,21 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
*result_count = count_result_broadcast;
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access, bool padding>
template<typename LOAD, typename STORE, typename ComputeType, int pack_size,
int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,
int rows_per_access, bool padding>
__global__ void LayerNormWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
static_assert(cols_per_thread % pack_size == 0, "");
using LoadType = typename LOAD::LoadType;
static_assert(max_cols_per_thread % pack_size == 0, "");
static_assert(min_cols_per_thread % pack_size == 0, "");
static_assert(thread_group_width <= kWarpSize, "");
static_assert(kWarpSize % thread_group_width == 0, "");
constexpr int num_packs = cols_per_thread / pack_size;
assert(cols <= cols_per_thread * thread_group_width);
ComputeType buf[rows_per_access][cols_per_thread];
constexpr int max_num_packs = max_cols_per_thread / pack_size;
constexpr int min_num_packs = min_cols_per_thread / pack_size;
assert(cols <= max_cols_per_thread * thread_group_width);
ComputeType buf[rows_per_access][max_cols_per_thread];
const int64_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
const int64_t num_global_thread_group = gridDim.x * blockDim.y;
const int64_t lane_id = threadIdx.x;
......@@ -301,13 +369,27 @@ __global__ void LayerNormWarpImpl(LOAD load, STORE store, const int64_t rows, co
thread_count[row_id] = 0;
ComputeType* row_buf = buf[row_id];
#pragma unroll
for (int pack_id = 0; pack_id < num_packs; ++pack_id) {
for (int pack_id = 0; pack_id < min_num_packs; ++pack_id) {
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
const int pack_offset = pack_id * pack_size;
LoadType pack[pack_size];
load.template load<pack_size>(pack, row + row_id, col);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
row_buf[pack_offset + i] = static_cast<ComputeType>(pack[i]);
WelfordCombine(row_buf[pack_offset + i], thread_mean + row_id, thread_m2 + row_id,
thread_count + row_id);
}
}
for (int pack_id = min_num_packs; pack_id < max_num_packs; ++pack_id) {
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
const int pack_offset = pack_id * pack_size;
if (!padding || col < cols) {
load.template load<pack_size>(row_buf + pack_offset, row + row_id, col);
LoadType pack[pack_size];
load.template load<pack_size>(pack, row + row_id, col);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
row_buf[pack_offset + i] = static_cast<ComputeType>(pack[i]);
WelfordCombine(row_buf[pack_offset + i], thread_mean + row_id, thread_m2 + row_id,
thread_count + row_id);
}
......@@ -336,11 +418,16 @@ __global__ void LayerNormWarpImpl(LOAD load, STORE store, const int64_t rows, co
inv_variance[global_row_id] = row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
for (int i = 0; i < max_cols_per_thread; ++i) {
row_buf[i] = (row_buf[i] - row_mean) * row_inv_var;
}
#pragma unroll
for (int i = 0; i < num_packs; ++i) {
for (int i = 0; i < min_num_packs; ++i) {
const int col = (i * thread_group_width + lane_id) * pack_size;
store.template store<pack_size>(row_buf + i * pack_size, global_row_id, col);
}
#pragma unroll
for (int i = min_num_packs; i < max_num_packs; ++i) {
const int col = (i * thread_group_width + lane_id) * pack_size;
if (!padding || col < cols) {
store.template store<pack_size>(row_buf + i * pack_size, global_row_id, col);
......@@ -350,9 +437,10 @@ __global__ void LayerNormWarpImpl(LOAD load, STORE store, const int64_t rows, co
}
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access, bool padding>
inline cudaError_t LaunchLayerNormWarpImpl(cudaStream_t stream, LOAD load, STORE store,
template<typename LOAD, typename STORE, typename ComputeType, int pack_size,
int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,
int rows_per_access, bool padding>
inline GPU(Error_t) LaunchLayerNormWarpImpl(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
......@@ -365,171 +453,129 @@ inline cudaError_t LaunchLayerNormWarpImpl(cudaStream_t stream, LOAD load, STORE
(rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;
int grid_dim_x;
{
cudaError_t err =
GetNumBlocks(LayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, padding>,
block_size, 0, num_blocks, waves, &grid_dim_x);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err = GetNumBlocks(
LayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread,
min_cols_per_thread, thread_group_width, rows_per_access, padding>,
block_size, 0, num_blocks, waves, &grid_dim_x);
if (err != GPU(Success)) { return err; }
}
LayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread, thread_group_width,
rows_per_access, padding>
LayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread, min_cols_per_thread,
thread_group_width, rows_per_access, padding>
<<<grid_dim_x, block_dim, 0, stream>>>(load, store, rows, cols, epsilon, mean, inv_variance);
return cudaPeekAtLastError();
return GPU(PeekAtLastError)();
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access>
inline cudaError_t DispatchLayerNormWarpImplPadding(cudaStream_t stream, LOAD load, STORE store,
template<typename LOAD, typename STORE, typename ComputeType, int pack_size,
int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,
int rows_per_access>
inline GPU(Error_t) DispatchLayerNormWarpImplPadding(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
if (cols == cols_per_thread * thread_group_width) {
return LaunchLayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, false>(
if (cols == max_cols_per_thread * thread_group_width) {
// when not padding, min_cols_per_thread must equals to max_cols_per_thread, pass
// max_cols_per_thread as min_cols_per_thread and max_cols_per_thread param.
return LaunchLayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread,
max_cols_per_thread, thread_group_width, rows_per_access, false>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else {
return LaunchLayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, true>(
return LaunchLayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread,
min_cols_per_thread, thread_group_width, rows_per_access, true>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size>
typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchLayerNormWarpImplCols(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,
typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormWarpImplCols(
GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean, ComputeType* inv_variance) {
if (cols <= 0) { return cudaErrorInvalidValue; }
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 2>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} else { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 1>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
if (cols <= 0) { return GPU(ErrorInvalidValue); }
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, 0, \
thread_group_width, 2>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} else { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, 0, \
thread_group_width, 1>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, kWarpSize, \
1>(stream, load, store, rows, cols, epsilon, mean, \
inv_variance); \
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if (cols <= (max_col)*kWarpSize) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, max_col, min_col, \
kWarpSize, 1>(stream, load, store, rows, cols, \
epsilon, mean, inv_variance); \
}
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(32)
DEFINE_ONE_ELIF(2, 1)
DEFINE_ONE_ELIF(4, 2)
DEFINE_ONE_ELIF(8, 4)
DEFINE_ONE_ELIF(12, 8)
DEFINE_ONE_ELIF(16, 12)
DEFINE_ONE_ELIF(20, 16)
DEFINE_ONE_ELIF(24, 20)
DEFINE_ONE_ELIF(28, 24)
DEFINE_ONE_ELIF(32, 28)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
return GPU(ErrorInvalidValue);
}
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size>
typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchLayerNormWarpImplCols(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,
typename std::enable_if<pack_size == 2, GPU(Error_t)>::type DispatchLayerNormWarpImplCols(
GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean, ComputeType* inv_variance) {
if (cols <= 0) { return cudaErrorInvalidValue; }
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 2>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} else { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 1>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, kWarpSize, \
1>(stream, load, store, rows, cols, epsilon, mean, \
inv_variance); \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
if (cols <= 0) { return GPU(ErrorInvalidValue); }
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, 0, \
thread_group_width, 2>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} else { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, 0, \
thread_group_width, 1>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size>
typename std::enable_if<pack_size == 4, cudaError_t>::type DispatchLayerNormWarpImplCols(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean, ComputeType* inv_variance) {
if (cols <= 0) { return cudaErrorInvalidValue; }
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 2>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} else { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 1>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, kWarpSize, \
1>(stream, load, store, rows, cols, epsilon, mean, \
inv_variance); \
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if ((cols <= (max_col)*kWarpSize) && (cols > (min_col)*kWarpSize)) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, max_col, min_col, \
kWarpSize, 1>(stream, load, store, rows, cols, \
epsilon, mean, inv_variance); \
}
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(32)
DEFINE_ONE_ELIF(4, 2)
DEFINE_ONE_ELIF(8, 4)
DEFINE_ONE_ELIF(12, 8)
DEFINE_ONE_ELIF(16, 12)
DEFINE_ONE_ELIF(20, 16)
DEFINE_ONE_ELIF(24, 20)
DEFINE_ONE_ELIF(28, 24)
DEFINE_ONE_ELIF(32, 28)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
return GPU(ErrorInvalidValue);
}
}
template<typename LOAD, typename STORE, typename ComputeType>
struct DispatchLayerNormWarpImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
GPU(Error_t) operator()(GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
if (cols % 4 == 0) {
return DispatchLayerNormWarpImplCols<LOAD, STORE, ComputeType, 4>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else if (cols % 2 == 0) {
if (cols % 2 == 0 && CanPackAs<LOAD>(load, 2) && CanPackAs<STORE>(store, 2)) {
return DispatchLayerNormWarpImplCols<LOAD, STORE, ComputeType, 2>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else {
......@@ -540,7 +586,7 @@ struct DispatchLayerNormWarpImplPackSize {
};
template<typename LOAD, typename STORE, typename ComputeType>
inline cudaError_t DispatchLayerNormWarpImpl(cudaStream_t stream, LOAD load, STORE store,
inline GPU(Error_t) DispatchLayerNormWarpImpl(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
......@@ -552,8 +598,9 @@ template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int
__global__ void LayerNormBlockSMemImpl(LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
using LoadType = typename LOAD::LoadType;
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
auto* buf = reinterpret_cast<ComputeType*>(shared_buf);
auto* buf = reinterpret_cast<LoadType*>(shared_buf);
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = static_cast<int>(cols) / pack_size;
......@@ -562,12 +609,12 @@ __global__ void LayerNormBlockSMemImpl(LOAD load, STORE store, const int64_t row
ComputeType thread_m2 = 0;
ComputeType thread_count = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
LoadType pack[pack_size];
load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
buf[i * num_packs + pack_id] = pack[i];
WelfordCombine(pack[i], &thread_mean, &thread_m2, &thread_count);
WelfordCombine(static_cast<ComputeType>(pack[i]), &thread_mean, &thread_m2, &thread_count);
}
}
ComputeType row_mean = 0;
......@@ -585,7 +632,7 @@ __global__ void LayerNormBlockSMemImpl(LOAD load, STORE store, const int64_t row
ComputeType pack[pack_size];
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack[i] = (buf[i * num_packs + pack_id] - row_mean) * row_inv_var;
pack[i] = (static_cast<ComputeType>(buf[i * num_packs + pack_id]) - row_mean) * row_inv_var;
}
store.template store<pack_size>(pack, row, pack_id * pack_size);
}
......@@ -593,88 +640,152 @@ __global__ void LayerNormBlockSMemImpl(LOAD load, STORE store, const int64_t row
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
inline cudaError_t LaunchLayerNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store,
inline GPU(Error_t) LaunchLayerNormBlockSMemImpl(GPU(Stream_t) stream, LOAD load, STORE store,
int smem, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err =
GPU(Error_t) err =
GetNumBlocks(LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size>,
block_size, smem, rows, waves, &grid_dim_x);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, smem, stream>>>(load, store, rows, cols, epsilon, mean,
inv_variance);
return cudaPeekAtLastError();
return GPU(PeekAtLastError)();
}
template<typename Func>
GPU(Error_t) MaximizeDynamicSharedMemorySize(Func func, const int max_smem_size) {
GPU(FuncAttributes) attr{};
#ifdef WITH_ROCM
GPU(Error_t) err = GPU(FuncGetAttributes)(&attr, (const void*)func);
#else
GPU(Error_t) err = GPU(FuncGetAttributes)(&attr, func);
#endif
if (err != GPU(Success)) { return err; }
constexpr int reserved_smem = 1024; // 1K
#ifdef WITH_ROCM
return GPU(FuncSetAttribute)((const void*)func, GPU(FuncAttributeMaxDynamicSharedMemorySize),
max_smem_size - attr.sharedSizeBytes - reserved_smem);
#else
return GPU(FuncSetAttribute)(func, GPU(FuncAttributeMaxDynamicSharedMemorySize),
max_smem_size - attr.sharedSizeBytes - reserved_smem);
#endif
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size>
inline cudaError_t TryDispatchLayerNormBlockSMemImplBlockSize(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,
inline GPU(Error_t) TryDispatchLayerNormBlockSMemImplBlockSize(
GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean, ComputeType* inv_variance, bool* success) {
constexpr int block_size_conf_1 = 128;
constexpr int block_size_conf_2 = 256;
constexpr int block_size_conf_3 = 512;
constexpr int block_size_conf_4 = 1024;
const size_t smem = cols * sizeof(ComputeType);
int max_active_blocks_conf_1;
int dev = 0;
{
GPU(Error_t) err = GPU(GetDevice)(&dev);
if (err != GPU(Success)) { return err; }
}
int sm_count = 0;
{
GPU(Error_t) err = GPU(DeviceGetAttribute)(&sm_count, GPUMultiProcessorCount, dev);
if (err != GPU(Success)) { return err; }
}
static const bool max_smem_configed = [=]() {
int max_smem_size = 0;
GPU(Error_t) err =
GPU(DeviceGetAttribute)(&max_smem_size, GPUMaxSharedMemoryPerBlockOptin, dev);
if (err != GPU(Success)) { return false; }
err = MaximizeDynamicSharedMemorySize(
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1>,
max_smem_size);
if (err != GPU(Success)) { return false; }
err = MaximizeDynamicSharedMemorySize(
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2>,
max_smem_size);
if (err != GPU(Success)) { return false; }
err = MaximizeDynamicSharedMemorySize(
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3>,
max_smem_size);
if (err != GPU(Success)) { return false; }
err = MaximizeDynamicSharedMemorySize(
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4>,
max_smem_size);
if (err != GPU(Success)) { return false; }
return true;
}();
const size_t smem = cols * sizeof(typename LOAD::LoadType);
int max_active_blocks_conf_1;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks_conf_1,
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1>,
block_size_conf_1, smem);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (max_active_blocks_conf_1 <= 0) {
*success = false;
return cudaSuccess;
return GPU(Success);
}
int max_active_blocks_conf_4;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks_conf_4,
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4>,
block_size_conf_4, smem);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (max_active_blocks_conf_4 == max_active_blocks_conf_1) {
if (max_active_blocks_conf_4 == max_active_blocks_conf_1
|| (max_active_blocks_conf_4 > 0 && rows <= sm_count)) {
*success = true;
return LaunchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4>(
stream, load, store, smem, rows, cols, epsilon, mean, inv_variance);
}
int max_active_blocks_conf_3;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks_conf_3,
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3>,
block_size_conf_3, smem);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (max_active_blocks_conf_3 == max_active_blocks_conf_1) {
if (max_active_blocks_conf_3 == max_active_blocks_conf_1
|| (max_active_blocks_conf_3 > 0 && rows <= sm_count)) {
*success = true;
return LaunchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3>(
stream, load, store, smem, rows, cols, epsilon, mean, inv_variance);
}
int max_active_blocks_conf_2;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks_conf_2,
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2>,
block_size_conf_2, smem);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {
if (max_active_blocks_conf_2 == max_active_blocks_conf_1
|| (max_active_blocks_conf_2 > 0 && rows <= sm_count)) {
*success = true;
return LaunchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2>(
stream, load, store, smem, rows, cols, epsilon, mean, inv_variance);
}
*success = true;
return LaunchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1>(
stream, load, store, smem, rows, cols, epsilon, mean, inv_variance);
......@@ -682,13 +793,13 @@ inline cudaError_t TryDispatchLayerNormBlockSMemImplBlockSize(
template<typename LOAD, typename STORE, typename ComputeType>
struct TryDispatchLayerNormBlockSMemImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
GPU(Error_t) operator()(GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance, bool* success) {
if (cols % 4 == 0) {
if (cols % 4 == 0 && CanPackAs<LOAD>(load, 4) && CanPackAs<STORE>(store, 4)) {
return TryDispatchLayerNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 4>(
stream, load, store, rows, cols, epsilon, mean, inv_variance, success);
} else if (cols % 2 == 0) {
} else if (cols % 2 == 0 && CanPackAs<LOAD>(load, 2) && CanPackAs<STORE>(store, 2)) {
return TryDispatchLayerNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 2>(
stream, load, store, rows, cols, epsilon, mean, inv_variance, success);
} else {
......@@ -699,7 +810,7 @@ struct TryDispatchLayerNormBlockSMemImplPackSize {
};
template<typename LOAD, typename STORE, typename ComputeType>
inline cudaError_t TryDispatchLayerNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store,
inline GPU(Error_t) TryDispatchLayerNormBlockSMemImpl(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance, bool* success) {
......@@ -708,9 +819,10 @@ inline cudaError_t TryDispatchLayerNormBlockSMemImpl(cudaStream_t stream, LOAD l
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
__global__ void LayerNormBlockUncachedImpl(LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon,
ComputeType* mean, ComputeType* inv_variance) {
__global__ void __launch_bounds__(1024)
LayerNormBlockUncachedImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean, ComputeType* inv_variance) {
using LoadType = typename LOAD::LoadType;
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = static_cast<int>(cols) / pack_size;
......@@ -719,11 +831,11 @@ __global__ void LayerNormBlockUncachedImpl(LOAD load, STORE store, const int64_t
ComputeType thread_m2 = 0;
ComputeType thread_count = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
LoadType pack[pack_size];
load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
WelfordCombine(pack[i], &thread_mean, &thread_m2, &thread_count);
WelfordCombine(static_cast<ComputeType>(pack[i]), &thread_mean, &thread_m2, &thread_count);
}
}
ComputeType row_mean = 0;
......@@ -738,18 +850,21 @@ __global__ void LayerNormBlockUncachedImpl(LOAD load, STORE store, const int64_t
inv_variance[row] = row_inv_var;
}
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
LoadType pack[pack_size];
ComputeType dst_pack[pack_size];
const int pack_offset = pack_id * pack_size;
load.template load<pack_size>(pack, row, pack_offset);
#pragma unroll
for (int i = 0; i < pack_size; ++i) { pack[i] = (pack[i] - row_mean) * row_inv_var; }
store.template store<pack_size>(pack, row, pack_offset);
for (int i = 0; i < pack_size; ++i) {
dst_pack[i] = (static_cast<ComputeType>(pack[i]) - row_mean) * row_inv_var;
}
store.template store<pack_size>(dst_pack, row, pack_offset);
}
}
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size>
inline cudaError_t LaunchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,
inline GPU(Error_t) LaunchLayerNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
......@@ -757,25 +872,25 @@ inline cudaError_t LaunchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD lo
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err =
GPU(Error_t) err =
GetNumBlocks(LayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size>,
block_size, 0, rows, waves, &grid_dim_x);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
LayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, 0, stream>>>(load, store, rows, cols, epsilon, mean, inv_variance);
return cudaPeekAtLastError();
return GPU(PeekAtLastError)();
}
template<typename LOAD, typename STORE, typename ComputeType>
struct DispatchLayerNormBlockUncachedImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
GPU(Error_t) operator()(GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
if (cols % 4 == 0) {
if (cols % 4 == 0 && CanPackAs<LOAD>(load, 4) && CanPackAs<STORE>(store, 4)) {
return LaunchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, 4>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else if (cols % 2 == 0) {
} else if (cols % 2 == 0 && CanPackAs<LOAD>(load, 2) && CanPackAs<STORE>(store, 2)) {
return LaunchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, 2>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else {
......@@ -786,7 +901,7 @@ struct DispatchLayerNormBlockUncachedImplPackSize {
};
template<typename LOAD, typename STORE, typename ComputeType>
inline cudaError_t DispatchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,
inline GPU(Error_t) DispatchLayerNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
......@@ -795,8 +910,8 @@ inline cudaError_t DispatchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD
}
template<typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
DispatchLayerNorm(GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
if (cols <= 1024) {
......@@ -805,22 +920,22 @@ DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t row
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err = TryDispatchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType>(
GPU(Error_t) err = TryDispatchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType>(
stream, load, store, rows, cols, epsilon, mean, inv_variance,
&dispatch_smem_impl_success);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (!dispatch_smem_impl_success) {
return DispatchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
return cudaSuccess;
return GPU(Success);
}
}
template<typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
inline typename std::enable_if<std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
DispatchLayerNorm(GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
return DispatchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType>(
......@@ -836,18 +951,22 @@ dx = cols * dy - sum_stats1 - normalized * sum_stats2
dx *= inv_var / cols
*/
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size, int cols_per_thread, int thread_group_width, int rows_per_access,
bool padding>
int pack_size, int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,
int rows_per_access>
__global__ void LayerNormGradWarpImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
static_assert(cols_per_thread % pack_size == 0, "");
constexpr int pack_per_thread = cols_per_thread / pack_size;
assert(cols <= cols_per_thread * thread_group_width);
using LoadTypeX = typename LOAD_X::LoadType;
using LoadTypeDy = typename LOAD_SCALED_DY::LoadType;
static_assert(max_cols_per_thread % pack_size == 0, "");
static_assert(min_cols_per_thread % pack_size == 0, "");
constexpr int max_num_packs = max_cols_per_thread / pack_size;
constexpr int min_num_packs = min_cols_per_thread / pack_size;
assert(cols <= max_cols_per_thread * thread_group_width);
static_assert(thread_group_width <= kWarpSize, "");
static_assert(kWarpSize % thread_group_width == 0, "");
ComputeType normalized_buf[rows_per_access][cols_per_thread];
ComputeType dy_buf[rows_per_access][cols_per_thread];
ComputeType normalized_buf[rows_per_access][max_cols_per_thread];
ComputeType dy_buf[rows_per_access][max_cols_per_thread];
const ComputeType one_over_cols = static_cast<ComputeType>(1.0) / static_cast<ComputeType>(cols);
const int64_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
const int64_t num_global_thread_group = gridDim.x * blockDim.y;
......@@ -867,18 +986,40 @@ __global__ void LayerNormGradWarpImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_
ComputeType* row_normalized_buf = normalized_buf[row_id];
ComputeType* row_dy_buf = dy_buf[row_id];
#pragma unroll
for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) {
for (int pack_id = 0; pack_id < min_num_packs; ++pack_id) {
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
const int pack_offset = pack_id * pack_size;
if (!padding || col < cols) {
load_x.template load<pack_size>(row_normalized_buf + pack_offset, global_row_id, col);
load_scaled_dy.template load<pack_size>(row_dy_buf + pack_offset, global_row_id, col);
LoadTypeX pack_x[pack_size];
LoadTypeDy pack_dy[pack_size];
load_x.template load<pack_size>(pack_x, global_row_id, col);
load_scaled_dy.template load<pack_size>(pack_dy, global_row_id, col);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
const int col_id = pack_offset + i;
// row_normalized_buf store x
row_normalized_buf[col_id] =
(static_cast<ComputeType>(pack_x[i]) - mean_val) * inv_variance_buf[row_id];
row_dy_buf[col_id] = static_cast<ComputeType>(pack_dy[i]);
sum_stats1[row_id] += row_dy_buf[col_id];
sum_stats2[row_id] += row_dy_buf[col_id] * row_normalized_buf[col_id];
}
}
#pragma unroll
for (int pack_id = min_num_packs; pack_id < max_num_packs; ++pack_id) {
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
const int pack_offset = pack_id * pack_size;
if (col < cols) {
LoadTypeX pack_x[pack_size];
LoadTypeDy pack_dy[pack_size];
load_x.template load<pack_size>(pack_x, global_row_id, col);
load_scaled_dy.template load<pack_size>(pack_dy, global_row_id, col);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
const int col_id = pack_offset + i;
// row_normalized_buf store x
row_normalized_buf[col_id] =
(row_normalized_buf[col_id] - mean_val) * inv_variance_buf[row_id];
(static_cast<ComputeType>(pack_x[i]) - mean_val) * inv_variance_buf[row_id];
row_dy_buf[col_id] = static_cast<ComputeType>(pack_dy[i]);
sum_stats1[row_id] += row_dy_buf[col_id];
sum_stats2[row_id] += row_dy_buf[col_id] * row_normalized_buf[col_id];
}
......@@ -901,16 +1042,29 @@ __global__ void LayerNormGradWarpImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_
ComputeType* row_dy_buf = dy_buf[row_id];
const ComputeType inv_variance_over_cols = inv_variance_buf[row_id] * one_over_cols;
#pragma unroll
for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) {
for (int pack_id = 0; pack_id < min_num_packs; ++pack_id) {
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
if (!padding || col < cols) {
const int pack_offset = pack_id * pack_size;
for (int i = 0; i < pack_size; ++i) {
const int col_id = pack_offset + i;
row_dy_buf[col_id] = (cols * row_dy_buf[col_id] - warp_sum_stats1[row_id]
- row_normalized_buf[col_id] * warp_sum_stats2[row_id])
* inv_variance_over_cols;
}
store.template store<pack_size>(row_dy_buf + pack_offset, global_row_id, col);
}
#pragma unroll
for (int pack_id = min_num_packs; pack_id < max_num_packs; ++pack_id) {
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
if (col < cols) {
const int pack_offset = pack_id * pack_size;
for (int i = 0; i < pack_size; ++i) {
const int col_id = pack_id * pack_size + i;
const int col_id = pack_offset + i;
row_dy_buf[col_id] = (cols * row_dy_buf[col_id] - warp_sum_stats1[row_id]
- row_normalized_buf[col_id] * warp_sum_stats2[row_id])
* inv_variance_over_cols;
}
store.template store<pack_size>(row_dy_buf + pack_id * pack_size, global_row_id, col);
store.template store<pack_size>(row_dy_buf + pack_offset, global_row_id, col);
}
}
}
......@@ -918,9 +1072,9 @@ __global__ void LayerNormGradWarpImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_
}
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size, int cols_per_thread, int thread_group_width, int rows_per_access,
bool padding>
inline cudaError_t LaunchLayerNormGradWarpImpl(cudaStream_t stream, LOAD_X load_x,
int pack_size, int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,
int rows_per_access>
inline GPU(Error_t) LaunchLayerNormGradWarpImpl(GPU(Stream_t) stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance, const int64_t rows,
......@@ -934,143 +1088,100 @@ inline cudaError_t LaunchLayerNormGradWarpImpl(cudaStream_t stream, LOAD_X load_
(rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(
LayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
cols_per_thread, thread_group_width, rows_per_access, padding>,
block_size, 0, num_blocks, waves, &grid_dim_x);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err =
GetNumBlocks(LayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
max_cols_per_thread, min_cols_per_thread,
thread_group_width, rows_per_access>,
block_size, 0, num_blocks, waves, &grid_dim_x);
if (err != GPU(Success)) { return err; }
}
LayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, padding>
LayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size, max_cols_per_thread,
min_cols_per_thread, thread_group_width, rows_per_access>
<<<grid_dim_x, block_dim, 0, stream>>>(load_x, load_scaled_dy, store, mean, inv_variance,
rows, cols);
return cudaPeekAtLastError();
return GPU(PeekAtLastError)();
}
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size, int cols_per_thread, int thread_group_width, int rows_per_access>
inline cudaError_t DispatchLayerNormGradWarpImplPadding(cudaStream_t stream, LOAD_X load_x,
int pack_size, int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,
int rows_per_access>
inline GPU(Error_t) DispatchLayerNormGradWarpImplPadding(GPU(Stream_t) stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
if (cols == cols_per_thread * thread_group_width) {
if (cols == max_cols_per_thread * thread_group_width) {
// when not padding, min_cols_per_thread must equals to max_cols_per_thread, pass
// max_cols_per_thread as min_cols_per_thread and max_cols_per_thread param.
return LaunchLayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
cols_per_thread, thread_group_width, rows_per_access, false>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
max_cols_per_thread, max_cols_per_thread, thread_group_width,
rows_per_access>(stream, load_x, load_scaled_dy, store, mean,
inv_variance, rows, cols);
} else {
return LaunchLayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
cols_per_thread, thread_group_width, rows_per_access, true>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
}
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size>
typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchLayerNormGradWarpImplCols(
cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
if (cols <= 0) { return cudaErrorInvalidValue; }
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, pack_size, thread_group_width, 2>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} else { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, pack_size, thread_group_width, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, col, kWarpSize, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
}
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
max_cols_per_thread, min_cols_per_thread, thread_group_width,
rows_per_access>(stream, load_x, load_scaled_dy, store, mean,
inv_variance, rows, cols);
}
}
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size>
typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchLayerNormGradWarpImplCols(
cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,
typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormGradWarpImplCols(
GPU(Stream_t) stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
if (cols <= 0) { return cudaErrorInvalidValue; }
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, pack_size, thread_group_width, 2>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} else { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, pack_size, thread_group_width, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} \
if (cols <= 0) { return GPU(ErrorInvalidValue); }
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, pack_size, 0, thread_group_width, 2>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} else { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, pack_size, 0, thread_group_width, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if (cols <= (max_col)*kWarpSize) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, col, kWarpSize, 1>( \
pack_size, max_col, min_col, kWarpSize, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(32)
DEFINE_ONE_ELIF(2, 1)
DEFINE_ONE_ELIF(4, 2)
DEFINE_ONE_ELIF(8, 4)
DEFINE_ONE_ELIF(12, 8)
DEFINE_ONE_ELIF(16, 12)
DEFINE_ONE_ELIF(20, 16)
DEFINE_ONE_ELIF(24, 20)
DEFINE_ONE_ELIF(28, 24)
DEFINE_ONE_ELIF(32, 28)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
return GPU(ErrorInvalidValue);
}
}
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
struct DispatchLayerNormGradWarpImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
GPU(Error_t) operator()(GPU(Stream_t) stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
if (cols % 2 == 0) {
return DispatchLayerNormGradWarpImplCols<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, 2>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
} else {
return DispatchLayerNormGradWarpImplCols<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, 1>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
return DispatchLayerNormGradWarpImplCols<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, 1>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
};
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
inline cudaError_t DispatchLayerNormGradWarpImpl(cudaStream_t stream, LOAD_X load_x,
inline GPU(Error_t) DispatchLayerNormGradWarpImpl(GPU(Stream_t) stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance,
......@@ -1085,9 +1196,11 @@ __global__ void LayerNormGradBlockSMemImpl(LOAD_X load_x, LOAD_SCALED_DY load_sc
STORE store, const ComputeType* mean,
const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
using LoadTypeX = typename LOAD_X::LoadType;
using LoadTypeDy = typename LOAD_SCALED_DY::LoadType;
extern __shared__ __align__(sizeof(double)) unsigned char grad_shared_buf[];
auto* normalized_buf = reinterpret_cast<ComputeType*>(grad_shared_buf);
auto* dy_buf = normalized_buf + cols;
auto* normalized_buf = reinterpret_cast<LoadTypeX*>(grad_shared_buf);
auto* dy_buf = reinterpret_cast<LoadTypeDy*>(normalized_buf + cols);
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = static_cast<int>(cols) / pack_size;
......@@ -1099,18 +1212,19 @@ __global__ void LayerNormGradBlockSMemImpl(LOAD_X load_x, LOAD_SCALED_DY load_sc
const ComputeType inv_variance_val = inv_variance[row];
const ComputeType inv_variance_over_cols = inv_variance_val * one_over_cols;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType x_pack[pack_size];
ComputeType dy_pack[pack_size];
LoadTypeX x_pack[pack_size];
LoadTypeDy dy_pack[pack_size];
load_x.template load<pack_size>(x_pack, row, pack_id * pack_size);
load_scaled_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
const int buf_offset = i * num_packs + pack_id;
ComputeType normalized = (x_pack[i] - mean_val) * inv_variance_val;
normalized_buf[buf_offset] = normalized;
ComputeType normalized =
(static_cast<ComputeType>(x_pack[i]) - mean_val) * inv_variance_val;
normalized_buf[buf_offset] = static_cast<LoadTypeX>(normalized);
dy_buf[buf_offset] = dy_pack[i];
sum_stats1 += dy_pack[i];
sum_stats2 += dy_pack[i] * normalized;
sum_stats1 += static_cast<ComputeType>(dy_pack[i]);
sum_stats2 += static_cast<ComputeType>(dy_pack[i]) * normalized;
}
}
const ComputeType row_sum_stats1 = BlockAllReduce<SumOp, ComputeType, block_size>(sum_stats1);
......@@ -1120,8 +1234,8 @@ __global__ void LayerNormGradBlockSMemImpl(LOAD_X load_x, LOAD_SCALED_DY load_sc
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
const int buf_offset = i * num_packs + pack_id;
pack[i] = (cols * dy_buf[buf_offset] - row_sum_stats1
- normalized_buf[buf_offset] * row_sum_stats2)
pack[i] = (cols * static_cast<ComputeType>(dy_buf[buf_offset]) - row_sum_stats1
- static_cast<ComputeType>(normalized_buf[buf_offset]) * row_sum_stats2)
* inv_variance_over_cols;
}
store.template store<pack_size>(pack, row, pack_id * pack_size);
......@@ -1131,7 +1245,7 @@ __global__ void LayerNormGradBlockSMemImpl(LOAD_X load_x, LOAD_SCALED_DY load_sc
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size, int block_size>
inline cudaError_t LaunchLayerNormGradBlockSMemImpl(cudaStream_t stream, LOAD_X load_x,
inline GPU(Error_t) LaunchLayerNormGradBlockSMemImpl(GPU(Stream_t) stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance, int smem,
......@@ -1139,86 +1253,139 @@ inline cudaError_t LaunchLayerNormGradBlockSMemImpl(cudaStream_t stream, LOAD_X
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE,
GPU(Error_t) err = GetNumBlocks(LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE,
ComputeType, pack_size, block_size>,
block_size, smem, rows, waves, &grid_dim_x);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, smem, stream>>>(load_x, load_scaled_dy, store, mean, inv_variance,
rows, cols);
return cudaPeekAtLastError();
return GPU(PeekAtLastError)();
}
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size>
inline cudaError_t TryDispatchLayerNormGradBlockSMemImplBlockSize(
cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,
inline GPU(Error_t) TryDispatchLayerNormGradBlockSMemImplBlockSize(
GPU(Stream_t) stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows,
const int64_t cols, bool* success) {
constexpr int block_size_conf_1 = 128;
constexpr int block_size_conf_2 = 256;
constexpr int block_size_conf_3 = 512;
constexpr int block_size_conf_4 = 1024;
const size_t smem = cols * sizeof(ComputeType) * 2;
int dev = 0;
{
GPU(Error_t) err = GPU(GetDevice)(&dev);
if (err != GPU(Success)) { return err; }
}
int sm_count = 0;
{
GPU(Error_t) err = GPU(DeviceGetAttribute)(&sm_count, GPUMultiProcessorCount, dev);
if (err != GPU(Success)) { return err; }
}
static const bool max_smem_configed = [=]() {
int max_smem_size = 0;
GPU(Error_t) err =
GPU(DeviceGetAttribute)(&max_smem_size, GPUMaxSharedMemoryPerBlockOptin, dev);
if (err != GPU(Success)) { return false; }
err = MaximizeDynamicSharedMemorySize(
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_1>,
max_smem_size);
if (err != GPU(Success)) { return false; }
err = MaximizeDynamicSharedMemorySize(
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_2>,
max_smem_size);
if (err != GPU(Success)) { return false; }
err = MaximizeDynamicSharedMemorySize(
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_3>,
max_smem_size);
if (err != GPU(Success)) { return false; }
err = MaximizeDynamicSharedMemorySize(
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_4>,
max_smem_size);
if (err != GPU(Success)) { return false; }
return true;
}();
using LoadTypeX = typename LOAD_X::LoadType;
using LoadTypeDy = typename LOAD_SCALED_DY::LoadType;
const size_t smem = cols * (sizeof(LoadTypeX) + sizeof(LoadTypeDy));
int max_active_blocks_conf_1;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks_conf_1,
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_1>,
block_size_conf_1, smem);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (max_active_blocks_conf_1 <= 0) {
*success = false;
return cudaSuccess;
return GPU(Success);
}
int max_active_blocks_conf_4;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks_conf_4,
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_4>,
block_size_conf_4, smem);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (max_active_blocks_conf_4 == max_active_blocks_conf_1) {
if (max_active_blocks_conf_4 == max_active_blocks_conf_1
|| (max_active_blocks_conf_4 > 0 && rows <= sm_count)) {
*success = true;
return LaunchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_4>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols);
}
int max_active_blocks_conf_3;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks_conf_3,
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_3>,
block_size_conf_3, smem);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (max_active_blocks_conf_3 == max_active_blocks_conf_1) {
if (max_active_blocks_conf_3 == max_active_blocks_conf_1
|| (max_active_blocks_conf_3 > 0 && rows <= sm_count)) {
*success = true;
return LaunchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_3>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols);
}
int max_active_blocks_conf_2;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks_conf_2,
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_2>,
block_size_conf_2, smem);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {
if (max_active_blocks_conf_2 == max_active_blocks_conf_1
|| (max_active_blocks_conf_2 > 0 && rows <= sm_count)) {
*success = true;
return LaunchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_2>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols);
}
*success = true;
return LaunchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_1>(stream, load_x, load_scaled_dy, store,
......@@ -1227,10 +1394,11 @@ inline cudaError_t TryDispatchLayerNormGradBlockSMemImplBlockSize(
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
struct TryDispatchLayerNormGradBlockSMemImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
GPU(Error_t) operator()(GPU(Stream_t) stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols, bool* success) {
if (cols % 2 == 0) {
if (cols % 2 == 0 && CanPackAs<LOAD_X>(load_x, 2)
&& CanPackAs<LOAD_SCALED_DY>(load_scaled_dy, 2) && CanPackAs<STORE>(store, 2)) {
return TryDispatchLayerNormGradBlockSMemImplBlockSize<LOAD_X, LOAD_SCALED_DY, STORE,
ComputeType, 2>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols, success);
......@@ -1243,7 +1411,7 @@ struct TryDispatchLayerNormGradBlockSMemImplPackSize {
};
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
inline cudaError_t TryDispatchLayerNormGradBlockSMemImpl(cudaStream_t stream, LOAD_X load_x,
inline GPU(Error_t) TryDispatchLayerNormGradBlockSMemImpl(GPU(Stream_t) stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance,
......@@ -1260,6 +1428,8 @@ __global__ void LayerNormGradBlockUncachedImpl(LOAD_X load_x, LOAD_SCALED_DY loa
STORE store, const ComputeType* mean,
const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
using LoadTypeX = typename LOAD_X::LoadType;
using LoadTypeDy = typename LOAD_SCALED_DY::LoadType;
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = static_cast<int>(cols) / pack_size;
......@@ -1271,75 +1441,134 @@ __global__ void LayerNormGradBlockUncachedImpl(LOAD_X load_x, LOAD_SCALED_DY loa
ComputeType sum_stats1 = 0;
ComputeType sum_stats2 = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType x_pack[pack_size];
ComputeType dy_pack[pack_size];
load_x.template load<pack_size>(x_pack, row, pack_id * pack_size);
load_scaled_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);
const int pack_offset = pack_id * pack_size;
LoadTypeX x_pack[pack_size];
LoadTypeDy dy_pack[pack_size];
load_x.template load<pack_size>(x_pack, row, pack_offset);
load_scaled_dy.template load<pack_size>(dy_pack, row, pack_offset);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
sum_stats1 += dy_pack[i];
sum_stats2 += dy_pack[i] * (x_pack[i] - mean_val) * inv_variance_val;
sum_stats1 += static_cast<ComputeType>(dy_pack[i]);
sum_stats2 += static_cast<ComputeType>(dy_pack[i])
* (static_cast<ComputeType>(x_pack[i]) - mean_val) * inv_variance_val;
}
}
const ComputeType row_sum_stats1 = BlockAllReduce<SumOp, ComputeType, block_size>(sum_stats1);
const ComputeType row_sum_stats2 = BlockAllReduce<SumOp, ComputeType, block_size>(sum_stats2);
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType x_pack[pack_size];
ComputeType dy_pack[pack_size];
load_x.template load<pack_size>(x_pack, row, pack_id * pack_size);
load_scaled_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);
const int pack_offset = pack_id * pack_size;
LoadTypeX x_pack[pack_size];
LoadTypeDy dy_pack[pack_size];
ComputeType dx_pack[pack_size];
load_x.template load<pack_size>(x_pack, row, pack_offset);
load_scaled_dy.template load<pack_size>(dy_pack, row, pack_offset);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
dy_pack[i] = (cols * dy_pack[i] - row_sum_stats1
- (x_pack[i] - mean_val) * inv_variance_val * row_sum_stats2)
* inv_variance_over_cols;
dx_pack[i] =
(cols * static_cast<ComputeType>(dy_pack[i]) - row_sum_stats1
- (static_cast<ComputeType>(x_pack[i]) - mean_val) * inv_variance_val * row_sum_stats2)
* inv_variance_over_cols;
}
store.template store<pack_size>(dy_pack, row, pack_id * pack_size);
store.template store<pack_size>(dx_pack, row, pack_offset);
}
}
}
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size>
inline cudaError_t LaunchLayerNormGradBlockUncachedImpl(cudaStream_t stream, LOAD_X load_x,
int pack_size, int block_size>
inline GPU(Error_t) LaunchLayerNormGradBlockUncachedImpl(GPU(Stream_t) stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
constexpr int block_size = 1024;
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err =
GPU(Error_t) err =
GetNumBlocks(LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
pack_size, block_size>,
block_size, 0, rows, waves, &grid_dim_x);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, 0, stream>>>(load_x, load_scaled_dy, store, mean, inv_variance,
rows, cols);
return cudaPeekAtLastError();
return GPU(PeekAtLastError)();
}
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size>
inline GPU(Error_t) TryDispatchLaunchLayerNormGradBlockUncachedImplBlockSize(
GPU(Stream_t) stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
int max_active_blocks = 0;
constexpr int block_size_conf_1 = 1024;
{
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks,
LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_1>,
block_size_conf_1, 0);
if (max_active_blocks > 0) {
return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
pack_size, block_size_conf_1>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
}
constexpr int block_size_conf_2 = 512;
{
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks,
LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_2>,
block_size_conf_2, 0);
if (max_active_blocks > 0) {
return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
pack_size, block_size_conf_2>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
}
constexpr int block_size_conf_3 = 256;
{
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks,
LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_3>,
block_size_conf_2, 0);
if (max_active_blocks > 0) {
return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
pack_size, block_size_conf_3>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
}
constexpr int block_size_conf_4 = 128;
return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_4>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
struct DispatchLayerNormGradBlockUncachedImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
GPU(Error_t) operator()(GPU(Stream_t) stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
if (cols % 2 == 0 && cols > kWarpSize) {
return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, 2>(
if (cols % 2 == 0 && CanPackAs<LOAD_X>(load_x, 2)
&& CanPackAs<LOAD_SCALED_DY>(load_scaled_dy, 2) && CanPackAs<STORE>(store, 2)
&& cols > kWarpSize) {
return TryDispatchLaunchLayerNormGradBlockUncachedImplBlockSize<LOAD_X, LOAD_SCALED_DY, STORE,
ComputeType, 2>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
} else {
return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, 1>(
return TryDispatchLaunchLayerNormGradBlockUncachedImplBlockSize<LOAD_X, LOAD_SCALED_DY, STORE,
ComputeType, 1>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
}
};
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
inline cudaError_t DispatchLayerNormGradBlockUncachedImpl(cudaStream_t stream, LOAD_X load_x,
inline GPU(Error_t) DispatchLayerNormGradBlockUncachedImpl(GPU(Stream_t) stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean,
const ComputeType* inv_variance,
......@@ -1350,8 +1579,8 @@ inline cudaError_t DispatchLayerNormGradBlockUncachedImpl(cudaStream_t stream, L
}
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLayerNormGrad(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
DispatchLayerNormGrad(GPU(Stream_t) stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
if (cols <= 1024) {
......@@ -1360,23 +1589,23 @@ DispatchLayerNormGrad(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_sc
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err =
GPU(Error_t) err =
TryDispatchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols,
&dispatch_smem_impl_success);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (!dispatch_smem_impl_success) {
return DispatchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
return cudaSuccess;
return GPU(Success);
}
}
template<typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLayerNormGrad(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
inline typename std::enable_if<std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
DispatchLayerNormGrad(GPU(Stream_t) stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
return DispatchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType>(
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_CUDA_RMS_NORM_H_
#define ONEFLOW_CORE_CUDA_RMS_NORM_H_
#include "oneflow/core/cuda/layer_norm.cuh"
namespace oneflow {
namespace cuda {
namespace rms_norm {
#ifdef WITH_ROCM
constexpr int kWarpSize = 64;
#else
constexpr int kWarpSize = 32;
#endif
template<typename T>
__inline__ __device__ T WarpReduceSum(T val) {
#ifdef WITH_ROCM
for (int mask = 32; mask > 0; mask /= 2) { val += __shfl_down(val, mask); }
#else
for (int mask = 16; mask > 0; mask /= 2) { val += __shfl_down_sync(0xffffffff, val, mask); }
#endif
return val;
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size,
int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,
int rows_per_access, bool padding>
__global__ void RmsNormWarpImpl(LOAD load, STORE store, const int nrow, const int ncol,
const double eps, ComputeType* inv_rms) {
static_assert(max_cols_per_thread % pack_size == 0, "");
static_assert(min_cols_per_thread % pack_size == 0, "");
static_assert(thread_group_width <= kWarpSize, "");
static_assert(kWarpSize % thread_group_width == 0, "");
constexpr int max_packs = max_cols_per_thread / pack_size;
constexpr int min_packs = min_cols_per_thread / pack_size;
assert(ncol <= max_cols_per_thread * thread_group_width);
ComputeType buf[rows_per_access][max_cols_per_thread];
const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
const int num_global_thread_groups = gridDim.x * blockDim.y;
for (int row_i = global_thread_group_id; row_i < nrow; row_i += num_global_thread_groups) {
ComputeType thread_square_sum[rows_per_access];
#pragma unroll
for (int row_j = 0; row_j < rows_per_access; ++row_j) {
thread_square_sum[row_j] = 0;
ComputeType* row_buf = buf[row_j];
const int row = row_i * rows_per_access + row_j;
#pragma unroll
for (int pack_i = 0; pack_i < min_packs; ++pack_i) {
const int pack_offset = pack_i * pack_size;
const int col = (pack_i * thread_group_width + threadIdx.x) * pack_size;
load.template load<pack_size>(row_buf + pack_offset, row, col);
#pragma unroll
for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
thread_square_sum[row_j] += row_buf[pack_offset + pack_j] * row_buf[pack_offset + pack_j];
}
}
#pragma unroll
for (int pack_i = min_packs; pack_i < max_packs; ++pack_i) {
const int pack_offset = pack_i * pack_size;
const int col = (pack_i * thread_group_width + threadIdx.x) * pack_size;
if (!padding || col < ncol) {
load.template load<pack_size>(row_buf + pack_offset, row, col);
#pragma unroll
for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
thread_square_sum[row_j] +=
row_buf[pack_offset + pack_j] * row_buf[pack_offset + pack_j];
}
} else {
#pragma unroll
for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
row_buf[pack_i * pack_size + pack_j] = 0;
}
}
}
}
ComputeType warp_square_sum[rows_per_access];
#pragma unroll
for (int row_j = 0; row_j < rows_per_access; ++row_j) {
const int row = row_i * rows_per_access + row_j;
ComputeType* row_buf = buf[row_j];
warp_square_sum[row_j] =
layer_norm::WarpAllReduce<layer_norm::SumOp, ComputeType, thread_group_width>(
thread_square_sum[row_j]);
ComputeType row_square_mean =
layer_norm::Div(warp_square_sum[row_j], static_cast<ComputeType>(ncol));
ComputeType row_inv_rms = layer_norm::Rsqrt(row_square_mean + static_cast<ComputeType>(eps));
if (threadIdx.x == 0) { inv_rms[row] = row_inv_rms; }
#pragma unroll
for (int col = 0; col < max_cols_per_thread; ++col) { row_buf[col] *= row_inv_rms; }
#pragma unroll
for (int pack_i = 0; pack_i < min_packs; ++pack_i) {
const int col = (pack_i * thread_group_width + threadIdx.x) * pack_size;
store.template store<pack_size>(row_buf + pack_i * pack_size, row, col);
}
#pragma unroll
for (int pack_i = min_packs; pack_i < max_packs; ++pack_i) {
const int col = (pack_i * thread_group_width + threadIdx.x) * pack_size;
if (!padding || col < ncol) {
store.template store<pack_size>(row_buf + pack_i * pack_size, row, col);
}
}
}
}
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size,
int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,
int rows_per_access, bool padding>
GPU(Error_t) LaunchRmsNormWarpImpl(GPU(Stream_t) stream, LOAD load, STORE store, const int64_t nrow,
const int64_t ncol, const double eps, ComputeType* inv_rms) {
constexpr int block_size = 128;
constexpr int waves = 32;
static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width;
const int64_t num_blocks =
(nrow / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;
int grid_dim_x;
{
GPU(Error_t) err = layer_norm::GetNumBlocks(
RmsNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread,
min_cols_per_thread, thread_group_width, rows_per_access, padding>,
block_size, 0, num_blocks, waves, &grid_dim_x);
if (err != GPU(Success)) { return err; }
}
dim3 block_dim(thread_group_width, thread_groups_per_block);
RmsNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread, min_cols_per_thread,
thread_group_width, rows_per_access, padding>
<<<grid_dim_x, block_dim, 0, stream>>>(load, store, static_cast<int>(nrow),
static_cast<int>(ncol), eps, inv_rms);
return GPU(PeekAtLastError)();
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size,
int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,
int rows_per_access>
GPU(Error_t) DispatchLaunchRmsNormWarpImplPadding(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t nrow, const int64_t ncol,
const double eps, ComputeType* inv_rms) {
if (ncol == max_cols_per_thread * thread_group_width) {
// when not padding, min_cols_per_thread must equals to max_cols_per_thread, pass
// max_cols_per_thread as min_cols_per_thread and max_cols_per_thread param.
return LaunchRmsNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread,
max_cols_per_thread, thread_group_width, rows_per_access, false>(
stream, load, store, nrow, ncol, eps, inv_rms);
} else {
return LaunchRmsNormWarpImpl<LOAD, STORE, ComputeType, pack_size, max_cols_per_thread,
min_cols_per_thread, thread_group_width, rows_per_access, true>(
stream, load, store, nrow, ncol, eps, inv_rms);
}
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size>
typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLaunchRmsNormWarpImplCols(
GPU(Stream_t) stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol,
const double eps, ComputeType* inv_rms) {
if (ncol <= 0) { return GPU(ErrorInvalidValue); }
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (ncol <= (thread_group_width)*pack_size) { \
if (nrow % 2 == 0) { \
return DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
0, thread_group_width, 2>( \
stream, load, store, nrow, ncol, eps, inv_rms); \
} else { \
return DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
0, thread_group_width, 1>( \
stream, load, store, nrow, ncol, eps, inv_rms); \
} \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if (ncol <= (max_col)*kWarpSize) { \
return DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, max_col, \
min_col, kWarpSize, 1>(stream, load, store, nrow, \
ncol, eps, inv_rms); \
}
DEFINE_ONE_ELIF(2, 1)
DEFINE_ONE_ELIF(4, 2)
DEFINE_ONE_ELIF(8, 4)
DEFINE_ONE_ELIF(12, 8)
DEFINE_ONE_ELIF(16, 12)
DEFINE_ONE_ELIF(20, 16)
DEFINE_ONE_ELIF(24, 20)
DEFINE_ONE_ELIF(28, 24)
DEFINE_ONE_ELIF(32, 28)
#undef DEFINE_ONE_ELIF
else {
return GPU(ErrorInvalidValue);
}
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size>
typename std::enable_if<pack_size == 2, GPU(Error_t)>::type DispatchLaunchRmsNormWarpImplCols(
GPU(Stream_t) stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol,
const double eps, ComputeType* inv_rms) {
if (ncol <= 0) { return GPU(ErrorInvalidValue); }
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (ncol <= (thread_group_width)*pack_size) { \
if (nrow % 2 == 0) { \
return DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
0, thread_group_width, 2>( \
stream, load, store, nrow, ncol, eps, inv_rms); \
} else { \
return DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
0, thread_group_width, 1>( \
stream, load, store, nrow, ncol, eps, inv_rms); \
} \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if ((ncol <= (max_col)*kWarpSize) && (ncol > (min_col)*kWarpSize)) { \
return DispatchLaunchRmsNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, max_col, \
min_col, kWarpSize, 1>(stream, load, store, nrow, \
ncol, eps, inv_rms); \
}
DEFINE_ONE_ELIF(4, 2)
DEFINE_ONE_ELIF(8, 4)
DEFINE_ONE_ELIF(12, 8)
DEFINE_ONE_ELIF(16, 12)
DEFINE_ONE_ELIF(20, 16)
DEFINE_ONE_ELIF(24, 20)
DEFINE_ONE_ELIF(28, 24)
DEFINE_ONE_ELIF(32, 28)
#undef DEFINE_ONE_ELIF
else {
return GPU(ErrorInvalidValue);
}
}
template<typename LOAD, typename STORE, typename ComputeType>
GPU(Error_t) DispatchLaunchRmsNormWarpImplPackSize(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t nrow, const int64_t ncol,
const double eps, ComputeType* inv_rms) {
if (ncol % 2 == 0 && layer_norm::CanPackAs<LOAD>(load, 2)
&& layer_norm::CanPackAs<STORE>(store, 2)) {
return DispatchLaunchRmsNormWarpImplCols<LOAD, STORE, ComputeType, 2>(stream, load, store, nrow,
ncol, eps, inv_rms);
} else {
return DispatchLaunchRmsNormWarpImplCols<LOAD, STORE, ComputeType, 1>(stream, load, store, nrow,
ncol, eps, inv_rms);
}
}
template<typename LOAD, typename STORE, typename ComputeType>
GPU(Error_t) DispatchLaunchRmsNormWarpImpl(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t nrow, const int64_t ncol, const double eps,
ComputeType* inv_rms) {
return DispatchLaunchRmsNormWarpImplPackSize(stream, load, store, nrow, ncol, eps, inv_rms);
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
__global__ void RmsNormBlockSMemImpl(LOAD load, STORE store, const int nrow, const int ncol,
const double eps, ComputeType* inv_rms) {
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
auto* buf = reinterpret_cast<ComputeType*>(shared_buf);
assert(ncol % pack_size == 0);
const int num_packs = ncol / pack_size;
for (int row = blockIdx.x; row < nrow; row += gridDim.x) {
ComputeType thread_square_sum = 0;
for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += block_size) {
ComputeType pack[pack_size];
const int col = pack_i * pack_size;
load.template load<pack_size>(pack, row, col);
#pragma unroll
for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
buf[pack_i * pack_size + pack_j] = pack[pack_j];
thread_square_sum += pack[pack_j] * pack[pack_j];
}
}
ComputeType row_square_sum =
layer_norm::BlockAllReduce<layer_norm::SumOp, ComputeType, block_size>(thread_square_sum);
ComputeType row_square_mean = layer_norm::Div(row_square_sum, static_cast<ComputeType>(ncol));
ComputeType row_inv_rms = layer_norm::Rsqrt(row_square_mean + static_cast<ComputeType>(eps));
if (threadIdx.x == 0) { inv_rms[row] = row_inv_rms; }
for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += block_size) {
ComputeType pack[pack_size];
#pragma unroll
for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
pack[pack_j] = buf[pack_i * pack_size + pack_j] * row_inv_rms;
}
const int col = pack_i * pack_size;
store.template store<pack_size>(pack, row, col);
}
}
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
GPU(Error_t) LaunchRmsNormBlockSMemImpl(GPU(Stream_t) stream, LOAD load, STORE store,
size_t smem_size, const int64_t nrow, const int64_t ncol,
const double eps, ComputeType* inv_rms) {
constexpr int waves = 32;
int grid_dim_x;
{
GPU(Error_t) err = layer_norm::GetNumBlocks(
RmsNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size>, block_size,
smem_size, nrow, waves, &grid_dim_x);
if (err != GPU(Success)) { return err; }
}
RmsNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, smem_size, stream>>>(load, store, nrow, ncol, eps, inv_rms);
return GPU(PeekAtLastError)();
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size>
GPU(Error_t) TryDispatchLaunchRmsNormBlockSMemImplBlockSize(GPU(Stream_t) stream, LOAD load,
STORE store, const int64_t nrow,
const int64_t ncol, const double eps,
ComputeType* inv_rms, bool* success) {
constexpr int block_size_conf_1 = 128;
constexpr int block_size_conf_2 = 256;
constexpr int block_size_conf_3 = 512;
constexpr int block_size_conf_4 = 1024;
const size_t smem_size = ncol * sizeof(ComputeType);
int max_active_blocks = 0;
int num_blocks = 0;
#define SELECT_BLOCK_SIZE_CONF(block_size_conf) \
{ \
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)( \
&num_blocks, RmsNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf>, \
block_size_conf, smem_size); \
if (err != GPU(Success)) { return err; } \
if (max_active_blocks == 0) { \
if (num_blocks <= max_active_blocks) { \
*success = false; \
return GPU(Success); \
} \
max_active_blocks = num_blocks; \
} else { \
if (num_blocks == max_active_blocks) { \
*success = true; \
return LaunchRmsNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf>( \
stream, load, store, smem_size, nrow, ncol, eps, inv_rms); \
} \
} \
}
SELECT_BLOCK_SIZE_CONF(block_size_conf_1)
SELECT_BLOCK_SIZE_CONF(block_size_conf_4)
SELECT_BLOCK_SIZE_CONF(block_size_conf_3)
SELECT_BLOCK_SIZE_CONF(block_size_conf_2)
#undef SELECT_BLOCK_SIZE_CONF
*success = true;
return LaunchRmsNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1>(
stream, load, store, smem_size, nrow, ncol, eps, inv_rms);
}
template<typename LOAD, typename STORE, typename ComputeType>
GPU(Error_t) TryDispatchLaunchRmsNormBlockSMemImplPackSize(GPU(Stream_t) stream, LOAD load,
STORE store, const int64_t nrow,
const int64_t ncol, const double eps,
ComputeType* inv_rms, bool* success) {
if (ncol % 4 == 0 && layer_norm::CanPackAs<LOAD>(load, 4)
&& layer_norm::CanPackAs<STORE>(store, 4)) {
return TryDispatchLaunchRmsNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 4>(
stream, load, store, nrow, ncol, eps, inv_rms, success);
} else if (ncol % 2 == 0 && layer_norm::CanPackAs<LOAD>(load, 2)
&& layer_norm::CanPackAs<STORE>(store, 2)) {
return TryDispatchLaunchRmsNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 2>(
stream, load, store, nrow, ncol, eps, inv_rms, success);
} else {
return TryDispatchLaunchRmsNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 1>(
stream, load, store, nrow, ncol, eps, inv_rms, success);
}
}
template<typename LOAD, typename STORE, typename ComputeType>
GPU(Error_t) TryDispatchLaunchRmsNormBlockSMemImpl(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t nrow, const int64_t ncol,
const double eps, ComputeType* inv_rms,
bool* success) {
return TryDispatchLaunchRmsNormBlockSMemImplPackSize(stream, load, store, nrow, ncol, eps,
inv_rms, success);
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
__global__ void RmsNormBlockUncachedImpl(LOAD load, STORE store, const int nrow, const int ncol,
const double eps, ComputeType* inv_rms) {
assert(ncol % pack_size == 0);
const int num_packs = ncol / pack_size;
for (int row = blockIdx.x; row < nrow; row += gridDim.x) {
ComputeType thread_square_sum = 0;
for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += block_size) {
ComputeType pack[pack_size];
const int col = pack_i * pack_size;
load.template load<pack_size>(pack, row, col);
#pragma unroll
for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
thread_square_sum += pack[pack_j] * pack[pack_j];
}
}
ComputeType row_square_sum =
layer_norm::BlockAllReduce<layer_norm::SumOp, ComputeType, block_size>(thread_square_sum);
ComputeType row_square_mean = layer_norm::Div(row_square_sum, static_cast<ComputeType>(ncol));
ComputeType row_inv_rms = layer_norm::Rsqrt(row_square_mean + static_cast<ComputeType>(eps));
if (threadIdx.x == 0) { inv_rms[row] = row_inv_rms; }
for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += block_size) {
ComputeType pack[pack_size];
const int col = pack_i * pack_size;
load.template load<pack_size>(pack, row, col);
#pragma unroll
for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
pack[pack_j] = pack[pack_j] * row_inv_rms;
}
store.template store<pack_size>(pack, row, col);
}
}
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size>
GPU(Error_t) LaunchRmsNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t nrow, const int64_t ncol, const double eps,
ComputeType* inv_rms) {
constexpr int block_size = 1024;
constexpr int waves = 32;
int grid_dim_x;
{
GPU(Error_t) err = layer_norm::GetNumBlocks(
RmsNormBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size>, block_size, 0,
nrow, waves, &grid_dim_x);
if (err != GPU(Success)) { return err; }
}
RmsNormBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, 0, stream>>>(load, store, nrow, ncol, eps, inv_rms);
return GPU(PeekAtLastError)();
}
template<typename LOAD, typename STORE, typename ComputeType>
GPU(Error_t) DispatchLaunchRmsNormBlockUncachedImplPackSize(GPU(Stream_t) stream, LOAD load,
STORE store, const int64_t nrow,
const int64_t ncol, const double eps,
ComputeType* inv_rms) {
if (ncol % 4 == 0 && layer_norm::CanPackAs<LOAD>(load, 4)
&& layer_norm::CanPackAs<STORE>(store, 4)) {
return LaunchRmsNormBlockUncachedImpl<LOAD, STORE, ComputeType, 4>(stream, load, store, nrow,
ncol, eps, inv_rms);
} else if (ncol % 2 == 0 && layer_norm::CanPackAs<LOAD>(load, 2)
&& layer_norm::CanPackAs<STORE>(store, 2)) {
return LaunchRmsNormBlockUncachedImpl<LOAD, STORE, ComputeType, 2>(stream, load, store, nrow,
ncol, eps, inv_rms);
} else {
return LaunchRmsNormBlockUncachedImpl<LOAD, STORE, ComputeType, 1>(stream, load, store, nrow,
ncol, eps, inv_rms);
}
}
template<typename LOAD, typename STORE, typename ComputeType>
GPU(Error_t) DispatchLaunchRmsNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t nrow, const int64_t ncol,
const double eps, ComputeType* inv_rms) {
return DispatchLaunchRmsNormBlockUncachedImplPackSize(stream, load, store, nrow, ncol, eps,
inv_rms);
}
template<typename LOAD, typename STORE, typename ComputeType>
typename std::enable_if<!std::is_same<ComputeType, double>::value, GPU(Error_t)>::type LaunchRmsNorm(
GPU(Stream_t) stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol,
const double eps, ComputeType* inv_rms) {
if (ncol <= 1024) {
return DispatchLaunchRmsNormWarpImpl(stream, load, store, nrow, ncol, eps, inv_rms);
} else {
bool dispatch_smem_impl_success = false;
{
GPU(Error_t) err = TryDispatchLaunchRmsNormBlockSMemImpl(stream, load, store, nrow, ncol, eps,
inv_rms, &dispatch_smem_impl_success);
if (err != GPU(Success)) { return err; }
}
if (!dispatch_smem_impl_success) {
return DispatchLaunchRmsNormBlockUncachedImpl(stream, load, store, nrow, ncol, eps, inv_rms);
}
return GPU(Success);
}
}
template<typename LOAD, typename STORE, typename ComputeType>
typename std::enable_if<std::is_same<ComputeType, double>::value, GPU(Error_t)>::type LaunchRmsNorm(
GPU(Stream_t) stream, LOAD load, STORE store, const int64_t nrow, const int64_t ncol,
const double eps, ComputeType* inv_rms) {
return DispatchLaunchRmsNormBlockUncachedImpl(stream, load, store, nrow, ncol, eps, inv_rms);
}
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,
int rows_per_access>
__global__ void RmsNormGradWarpImpl(const int nrow, const int ncol, LOAD_X load_x, LOAD_DY load_dy,
STORE store, const ComputeType* inv_rms) {
static_assert(max_cols_per_thread % pack_size == 0, "");
static_assert(min_cols_per_thread % pack_size == 0, "");
static_assert(thread_group_width <= kWarpSize, "");
static_assert(kWarpSize % thread_group_width == 0, "");
assert(ncol <= max_cols_per_thread * thread_group_width);
constexpr int max_packs = max_cols_per_thread / pack_size;
constexpr int min_packs = min_cols_per_thread / pack_size;
ComputeType normalized_buf[rows_per_access][max_cols_per_thread];
ComputeType dy_buf[rows_per_access][max_cols_per_thread];
const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
const int num_global_thread_group = gridDim.x * blockDim.y;
for (int row_i = global_thread_group_id; row_i < nrow; row_i += num_global_thread_group) {
ComputeType sum_stats[rows_per_access];
ComputeType inv_rms_buf[rows_per_access];
#pragma unroll
for (int row_j = 0; row_j < rows_per_access; ++row_j) {
const int global_row = row_i * rows_per_access + row_j;
sum_stats[row_j] = 0;
inv_rms_buf[row_j] = inv_rms[global_row];
ComputeType* row_normalized_buf = normalized_buf[row_j];
ComputeType* row_dy_buf = dy_buf[row_j];
#pragma unroll
for (int pack_i = 0; pack_i < min_packs; ++pack_i) {
const int pack_offset = pack_i * pack_size;
const int global_col = (pack_i * thread_group_width + threadIdx.x) * pack_size;
load_x.template load<pack_size>(row_normalized_buf + pack_offset, global_row, global_col);
load_dy.template load<pack_size>(row_dy_buf + pack_offset, global_row, global_col);
#pragma unroll
for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
const int col = pack_offset + pack_j;
row_normalized_buf[col] = row_normalized_buf[col] * inv_rms_buf[row_j];
sum_stats[row_j] += row_dy_buf[col] * row_normalized_buf[col];
}
}
#pragma unroll
for (int pack_i = min_packs; pack_i < max_packs; ++pack_i) {
const int pack_offset = pack_i * pack_size;
const int global_col = (pack_i * thread_group_width + threadIdx.x) * pack_size;
if (global_col < ncol) {
load_x.template load<pack_size>(row_normalized_buf + pack_offset, global_row, global_col);
load_dy.template load<pack_size>(row_dy_buf + pack_offset, global_row, global_col);
#pragma unroll
for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
const int col = pack_offset + pack_j;
row_normalized_buf[col] = row_normalized_buf[col] * inv_rms_buf[row_j];
sum_stats[row_j] += row_dy_buf[col] * row_normalized_buf[col];
}
}
}
}
ComputeType warp_sum_stats[rows_per_access];
#pragma unroll
for (int row_j = 0; row_j < rows_per_access; ++row_j) {
warp_sum_stats[row_j] =
layer_norm::WarpAllReduce<layer_norm::SumOp, ComputeType, thread_group_width>(
sum_stats[row_j]);
}
#pragma unroll
for (int row_j = 0; row_j < rows_per_access; ++row_j) {
const int global_row = row_i * rows_per_access + row_j;
ComputeType* row_normalized_buf = normalized_buf[row_j];
ComputeType* row_dy_buf = dy_buf[row_j];
#pragma unroll
for (int pack_i = 0; pack_i < min_packs; ++pack_i) {
const int pack_offset = pack_i * pack_size;
const int global_col = (pack_i * thread_group_width + threadIdx.x) * pack_size;
for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
const int col = pack_offset + pack_j;
const ComputeType norm_val =
layer_norm::Div(row_normalized_buf[col], static_cast<ComputeType>(ncol));
row_dy_buf[col] =
(row_dy_buf[col] - norm_val * warp_sum_stats[row_j]) * inv_rms_buf[row_j];
}
store.template store<pack_size>(row_dy_buf + pack_offset, global_row, global_col);
}
#pragma unroll
for (int pack_i = min_packs; pack_i < max_packs; ++pack_i) {
const int pack_offset = pack_i * pack_size;
const int global_col = (pack_i * thread_group_width + threadIdx.x) * pack_size;
if (global_col < ncol) {
for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
const int col = pack_offset + pack_j;
const ComputeType norm_val =
layer_norm::Div(row_normalized_buf[col], static_cast<ComputeType>(ncol));
row_dy_buf[col] =
(row_dy_buf[col] - norm_val * warp_sum_stats[row_j]) * inv_rms_buf[row_j];
}
store.template store<pack_size>(row_dy_buf + pack_offset, global_row, global_col);
}
}
}
}
}
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int max_cols_per_thread, int min_cols_per_thread, int thread_group_width,
int rows_per_access>
GPU(Error_t) LaunchRmsNormGradWarpImpl(GPU(Stream_t) stream, const int nrow, const int ncol,
LOAD_X load_x, LOAD_DY load_dy, STORE store,
const ComputeType* inv_rms) {
constexpr int block_size = 128;
constexpr int waves = 32;
static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width;
const int64_t num_blocks =
(nrow / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;
int grid_dim_x;
{
GPU(Error_t) err = layer_norm::GetNumBlocks(
RmsNormGradWarpImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, max_cols_per_thread,
min_cols_per_thread, thread_group_width, rows_per_access>,
block_size, 0, num_blocks, waves, &grid_dim_x);
if (err != GPU(Success)) { return err; }
}
dim3 block_dim(thread_group_width, thread_groups_per_block);
RmsNormGradWarpImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, max_cols_per_thread,
min_cols_per_thread, thread_group_width, rows_per_access>
<<<grid_dim_x, block_dim, 0, stream>>>(nrow, ncol, load_x, load_dy, store, inv_rms);
return GPU(PeekAtLastError)();
}
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size>
typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLaunchRmsNormGradWarpImplCols(
GPU(Stream_t) stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy,
STORE store, const ComputeType* inv_rms) {
if (ncol <= 0) { return GPU(ErrorInvalidValue); }
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (ncol <= (thread_group_width)*pack_size) { \
if (nrow % 2 == 0) { \
return LaunchRmsNormGradWarpImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, pack_size, \
0, thread_group_width, 2>(stream, nrow, ncol, load_x, \
load_dy, store, inv_rms); \
} else { \
return LaunchRmsNormGradWarpImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, pack_size, \
0, thread_group_width, 1>(stream, nrow, ncol, load_x, \
load_dy, store, inv_rms); \
} \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if (ncol <= (max_col)*kWarpSize) { \
return LaunchRmsNormGradWarpImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, max_col, \
min_col, kWarpSize, 1>(stream, nrow, ncol, load_x, load_dy, \
store, inv_rms); \
}
DEFINE_ONE_ELIF(2, 1)
DEFINE_ONE_ELIF(4, 2)
DEFINE_ONE_ELIF(8, 4)
DEFINE_ONE_ELIF(12, 8)
DEFINE_ONE_ELIF(16, 12)
DEFINE_ONE_ELIF(20, 16)
DEFINE_ONE_ELIF(24, 20)
DEFINE_ONE_ELIF(28, 24)
DEFINE_ONE_ELIF(32, 28)
#undef DEFINE_ONE_ELIF
else {
return GPU(ErrorInvalidValue);
}
}
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType>
GPU(Error_t) DispatchLaunchRmsNormGradWarpImplPackSize(GPU(Stream_t) stream, const int64_t nrow,
const int64_t ncol, LOAD_X load_x,
LOAD_DY load_dy, STORE store,
const ComputeType* inv_rms) {
return DispatchLaunchRmsNormGradWarpImplCols<LOAD_X, LOAD_DY, STORE, ComputeType, 1>(
stream, nrow, ncol, load_x, load_dy, store, inv_rms);
}
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int block_size>
__global__ void RmsNormGradBlockSMemImpl(const int nrow, const int ncol, LOAD_X load_x,
LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) {
extern __shared__ __align__(sizeof(double)) unsigned char dyn_smem[];
// dynamic shared memory for caching x and dy
auto* normalized_buf = reinterpret_cast<ComputeType*>(dyn_smem);
auto* dy_buf = normalized_buf + ncol;
assert(ncol % pack_size == 0);
const int num_packs = ncol / pack_size;
for (int row = blockIdx.x; row < nrow; row += gridDim.x) {
ComputeType sum_stats = 0;
const ComputeType inv_rms_val = inv_rms[row];
for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += blockDim.x) {
ComputeType x_pack[pack_size];
ComputeType dy_pack[pack_size];
const int pack_offset = pack_i * pack_size;
load_x.template load<pack_size>(x_pack, row, pack_offset);
load_dy.template load<pack_size>(dy_pack, row, pack_offset);
#pragma unroll
for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
const int col = pack_offset + pack_j;
normalized_buf[col] = x_pack[pack_j] * inv_rms_val;
dy_buf[col] = dy_pack[pack_j];
sum_stats += dy_buf[col] * normalized_buf[col];
}
}
const ComputeType row_sum_stats =
layer_norm::BlockAllReduce<layer_norm::SumOp, ComputeType, block_size>(sum_stats);
for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += blockDim.x) {
ComputeType pack[pack_size];
const int pack_offset = pack_i * pack_size;
#pragma unroll
for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
const int col = pack_offset + pack_j;
const ComputeType norm_val =
layer_norm::Div(normalized_buf[col], static_cast<ComputeType>(ncol));
pack[pack_j] = (dy_buf[col] - norm_val * row_sum_stats) * inv_rms_val;
}
store.template store<pack_size>(pack, row, pack_offset);
}
}
}
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int block_size>
GPU(Error_t) LaunchRmsNormGradBlockSMemImpl(GPU(Stream_t) stream, const int64_t nrow,
const int64_t ncol, const size_t smem_size,
LOAD_X load_x, LOAD_DY load_dy, STORE store,
const ComputeType* inv_rms) {
constexpr int waves = 32;
int grid_dim_x;
{
GPU(Error_t) err = layer_norm::GetNumBlocks(
RmsNormGradBlockSMemImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, block_size>,
block_size, smem_size, nrow, waves, &grid_dim_x);
if (err != GPU(Success)) { return err; }
}
RmsNormGradBlockSMemImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, smem_size, stream>>>(
static_cast<int>(nrow), static_cast<int>(ncol), load_x, load_dy, store, inv_rms);
return GPU(PeekAtLastError)();
}
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size>
GPU(Error_t) TryDispatchLaunchRmsNormGradBlockSMemImplBlockSize(
GPU(Stream_t) stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy,
STORE store, const ComputeType* inv_rms, bool* success) {
constexpr int block_size_conf_1 = 128;
constexpr int block_size_conf_2 = 256;
constexpr int block_size_conf_3 = 512;
constexpr int block_size_conf_4 = 1024;
const size_t smem_size = ncol * sizeof(ComputeType) * 2; // ncol * 2 for caching x and dy both
int max_active_blocks = 0;
int num_blocks = 0;
#define SELECT_BLOCK_SIZE_CONF(block_size_conf) \
{ \
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)( \
&num_blocks, \
RmsNormGradBlockSMemImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, block_size_conf>, \
block_size_conf, smem_size); \
if (err != GPU(Success)) { return err; } \
if (max_active_blocks == 0) { \
if (num_blocks <= max_active_blocks) { \
*success = false; \
return GPU(Success); \
} \
max_active_blocks = num_blocks; \
} else { \
if (num_blocks == max_active_blocks) { \
*success = true; \
return LaunchRmsNormGradBlockSMemImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, \
block_size_conf>(stream, nrow, ncol, smem_size, \
load_x, load_dy, store, inv_rms); \
} \
} \
}
SELECT_BLOCK_SIZE_CONF(block_size_conf_1)
SELECT_BLOCK_SIZE_CONF(block_size_conf_4)
SELECT_BLOCK_SIZE_CONF(block_size_conf_3)
SELECT_BLOCK_SIZE_CONF(block_size_conf_2)
#undef SELECT_BLOCK_SIZE_CONF
*success = true;
return LaunchRmsNormGradBlockSMemImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_1>(stream, nrow, ncol, smem_size, load_x,
load_dy, store, inv_rms);
}
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType>
GPU(Error_t) TryDispatchLaunchRmsNormGradBlockSMemImplPackSize(
GPU(Stream_t) stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy,
STORE store, const ComputeType* inv_rms, bool* success) {
if (ncol % 2 == 0 && layer_norm::CanPackAs<LOAD_X>(load_x, 2)
&& layer_norm::CanPackAs<LOAD_DY>(load_dy, 2) && layer_norm::CanPackAs<STORE>(store, 2)) {
return TryDispatchLaunchRmsNormGradBlockSMemImplBlockSize<LOAD_X, LOAD_DY, STORE, ComputeType,
2>(stream, nrow, ncol, load_x,
load_dy, store, inv_rms, success);
} else {
return TryDispatchLaunchRmsNormGradBlockSMemImplBlockSize<LOAD_X, LOAD_DY, STORE, ComputeType,
1>(stream, nrow, ncol, load_x,
load_dy, store, inv_rms, success);
}
}
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int block_size>
__global__ void RmsNormGradBlockUncachedImpl(const int nrow, const int ncol, LOAD_X load_x,
LOAD_DY load_dy, STORE store,
const ComputeType* inv_rms) {
assert(ncol % pack_size == 0);
const int num_packs = ncol / pack_size;
for (int row = blockIdx.x; row < nrow; row += gridDim.x) {
const ComputeType inv_rms_val = inv_rms[row];
ComputeType sum_stats = 0;
for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += blockDim.x) {
ComputeType x_pack[pack_size];
ComputeType dy_pack[pack_size];
const int pack_offset = pack_i * pack_size;
load_x.template load<pack_size>(x_pack, row, pack_offset);
load_dy.template load<pack_size>(dy_pack, row, pack_offset);
#pragma unroll
for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
sum_stats += dy_pack[pack_j] * x_pack[pack_j] * inv_rms_val;
}
}
const ComputeType row_sum_stats =
layer_norm::BlockAllReduce<layer_norm::SumOp, ComputeType, block_size>(sum_stats);
for (int pack_i = threadIdx.x; pack_i < num_packs; pack_i += blockDim.x) {
ComputeType x_pack[pack_size];
ComputeType dy_pack[pack_size];
const int pack_offset = pack_i * pack_size;
load_x.template load<pack_size>(x_pack, row, pack_offset);
load_dy.template load<pack_size>(dy_pack, row, pack_offset);
#pragma unroll
for (int pack_j = 0; pack_j < pack_size; ++pack_j) {
const ComputeType norm_val =
layer_norm::Div(x_pack[pack_j] * inv_rms_val, static_cast<ComputeType>(ncol));
dy_pack[pack_j] = (dy_pack[pack_j] - norm_val * row_sum_stats) * inv_rms_val;
}
store.template store<pack_size>(dy_pack, row, pack_offset);
}
}
}
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int block_size>
GPU(Error_t) LaunchRmsNormGradBlockUncachedImpl(GPU(Stream_t) stream, const int64_t nrow,
const int64_t ncol, LOAD_X load_x, LOAD_DY load_dy,
STORE store, const ComputeType* inv_rms) {
constexpr int waves = 32;
int grid_dim_x;
{
GPU(Error_t) err = layer_norm::GetNumBlocks(
RmsNormGradBlockUncachedImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, block_size>,
block_size, 0, nrow, waves, &grid_dim_x);
if (err != GPU(Success)) { return err; }
}
RmsNormGradBlockUncachedImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, 0, stream>>>(nrow, ncol, load_x, load_dy, store, inv_rms);
return GPU(PeekAtLastError)();
}
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size>
GPU(Error_t) DispatchLaunchRmsNormGradBlockUncachedImplBlockSize(GPU(Stream_t) stream,
const int64_t nrow,
const int64_t ncol, LOAD_X load_x,
LOAD_DY load_dy, STORE store,
const ComputeType* inv_rms) {
constexpr int block_size_conf_1 = 128;
constexpr int block_size_conf_2 = 256;
constexpr int block_size_conf_3 = 512;
constexpr int block_size_conf_4 = 1024;
int max_active_blocks = 0;
#define SELECT_BLOCK_SIZE_CONF(block_size_conf) \
{ \
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)( \
&max_active_blocks, \
RmsNormGradBlockUncachedImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, \
block_size_conf>, \
block_size_conf, 0); \
if (err != GPU(Success)) { return err; } \
if (max_active_blocks > 0) { \
return LaunchRmsNormGradBlockUncachedImpl<LOAD_X, LOAD_DY, STORE, ComputeType, pack_size, \
block_size_conf>(stream, nrow, ncol, load_x, \
load_dy, store, inv_rms); \
} \
}
SELECT_BLOCK_SIZE_CONF(block_size_conf_4)
SELECT_BLOCK_SIZE_CONF(block_size_conf_3)
SELECT_BLOCK_SIZE_CONF(block_size_conf_2)
SELECT_BLOCK_SIZE_CONF(block_size_conf_1)
#undef SELECT_BLOCK_SIZE_CONF
return GPU(ErrorInvalidValue);
}
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType>
GPU(Error_t) DispatchLaunchRmsNormGradBlockUncachedImplPackSize(GPU(Stream_t) stream,
const int64_t nrow,
const int64_t ncol, LOAD_X load_x,
LOAD_DY load_dy, STORE store,
const ComputeType* inv_rms) {
if (ncol % 2 == 0 && layer_norm::CanPackAs<LOAD_X>(load_x, 2)
&& layer_norm::CanPackAs<LOAD_DY>(load_dy, 2) && layer_norm::CanPackAs<STORE>(store, 2)
&& ncol > kWarpSize) {
return DispatchLaunchRmsNormGradBlockUncachedImplBlockSize<LOAD_X, LOAD_DY, STORE, ComputeType,
2>(stream, nrow, ncol, load_x,
load_dy, store, inv_rms);
} else {
return DispatchLaunchRmsNormGradBlockUncachedImplBlockSize<LOAD_X, LOAD_DY, STORE, ComputeType,
1>(stream, nrow, ncol, load_x,
load_dy, store, inv_rms);
}
}
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType>
typename std::enable_if<!std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
LaunchRmsNormGrad(GPU(Stream_t) stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x,
LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) {
if (ncol <= 1024) {
return DispatchLaunchRmsNormGradWarpImplPackSize(stream, nrow, ncol, load_x, load_dy, store,
inv_rms);
} else {
bool dispatch_smem_impl_success = false;
{
GPU(Error_t) err = TryDispatchLaunchRmsNormGradBlockSMemImplPackSize(
stream, nrow, ncol, load_x, load_dy, store, inv_rms, &dispatch_smem_impl_success);
if (err != GPU(Success)) { return err; }
}
if (!dispatch_smem_impl_success) {
return DispatchLaunchRmsNormGradBlockUncachedImplPackSize(stream, nrow, ncol, load_x, load_dy,
store, inv_rms);
}
return GPU(Success);
}
}
template<typename LOAD_X, typename LOAD_DY, typename STORE, typename ComputeType>
typename std::enable_if<std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
LaunchRmsNormGrad(GPU(Stream_t) stream, const int64_t nrow, const int64_t ncol, LOAD_X load_x,
LOAD_DY load_dy, STORE store, const ComputeType* inv_rms) {
return DispatchLaunchRmsNormGradBlockUncachedImplPackSize(stream, nrow, ncol, load_x, load_dy,
store, inv_rms);
}
template<int nproc_per_thread, typename T, typename ComputeType>
__global__ void RmsNormParamGrad(int nrow, int ncol, const T* __restrict__ dy,
const T* __restrict__ x, const ComputeType* __restrict__ inv_rms,
T* __restrict__ b_weight_grad) {
__shared__ ComputeType dweight[kWarpSize][kWarpSize + 1];
ComputeType dweight_sum[nproc_per_thread];
#pragma unroll
for (int i = 0; i < nproc_per_thread; ++i) { dweight_sum[i] = 0; }
const int col = blockIdx.x * blockDim.x + threadIdx.x;
if (col < ncol) {
// a wave for one traverse (when nrow > warp_size * grad_dim_y)
for (int j = blockIdx.y * kWarpSize + threadIdx.y; j < nrow; j += kWarpSize * gridDim.y) {
#pragma unroll
for (int i = 0; i < nproc_per_thread; ++i) {
int row = j + i * blockDim.y;
if (row < nrow) {
int offset = row * ncol + col;
const ComputeType dy_val = static_cast<ComputeType>(dy[offset]);
const ComputeType x_val = static_cast<ComputeType>(x[offset]);
const ComputeType inv_rms_val = inv_rms[row];
// collect dx from waves
dweight_sum[i] += dy_val * x_val * inv_rms_val;
}
}
}
}
// broadcast sum to the nproc_per_thread number rows
// each warp process the nproc_per_thread number rows of smem
#pragma unroll
for (int i = 0; i < nproc_per_thread; ++i) {
dweight[i * blockDim.y + threadIdx.y][threadIdx.x] = dweight_sum[i];
}
__syncthreads();
// transpose access for leveraging warp to reduce rows in a block
#pragma unroll
for (int i = 0; i < nproc_per_thread; ++i) {
// the first col of block threads is for storing the reduced sum of rows,
// and each first col thread is writing the nproc_per_thread number cols of output
const int row_in_block = threadIdx.y + i * blockDim.y;
const int col = blockIdx.x * blockDim.x + row_in_block;
if (col < ncol) {
// each warp process a col in which reduce sum all rows
ComputeType dweight_val = dweight[threadIdx.x][row_in_block];
ComputeType global_dweight = WarpReduceSum<ComputeType>(dweight_val);
if (threadIdx.x == 0) {
const int offset = blockIdx.y * ncol + col;
b_weight_grad[offset] = global_dweight;
}
}
}
}
template<int nproc_per_thread, typename T>
GPU(Error_t) GetGrid2Dim(const int64_t nrow, const int64_t ncol, int block_dim_x, int block_dim_y,
int* grid_dim_x, int* grid_dim_y) {
const int tile_size = block_dim_x;
if (nproc_per_thread * block_dim_y != tile_size) { return GPU(ErrorInvalidValue); }
*grid_dim_x = (ncol + tile_size - 1) / tile_size;
const int num_blocks_y = (nrow + tile_size - 1) / tile_size;
using ComputeType = typename layer_norm::DefaultComputeType<T>::type;
GPU(Error_t) err = layer_norm::GetNumBlocks(RmsNormParamGrad<nproc_per_thread, T, ComputeType>,
block_dim_x * block_dim_y, /*dynamic_smem_size*/ 0,
num_blocks_y, /*waves*/ 1, grid_dim_y);
if (err != GPU(Success)) { return err; }
return GPU(Success);
}
} // namespace rms_norm
} // namespace cuda
} // namespace oneflow
#endif // ONEFLOW_CORE_CUDA_RMS_NORM_H_
......@@ -17,10 +17,15 @@ limitations under the License.
#ifndef ONEFLOW_CORE_CUDA_SOFTMAX_H_
#define ONEFLOW_CORE_CUDA_SOFTMAX_H_
#ifdef WITH_ROCM
#include "hip/hip_runtime.h"
#include <hipcub/hipcub.hpp>
#else
#include <cuda.h>
#include <cub/cub.cuh>
#include <math_constants.h>
#endif
#include <assert.h>
#include <cuda.h>
#if CUDA_VERSION >= 11000
#include <cuda_bf16.h>
......@@ -32,7 +37,11 @@ namespace cuda {
namespace softmax {
#ifdef WITH_ROCM
constexpr int kWarpSize = 64;
#else
constexpr int kWarpSize = 32;
#endif
template<typename T>
struct SumOp {
......@@ -47,14 +56,22 @@ struct MaxOp {
template<template<typename> class ReductionOp, typename T, int thread_group_width = kWarpSize>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
#ifdef WITH_ROCM
val = ReductionOp<T>()(val, __shfl_xor(val, mask));
#else
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
#endif
}
return val;
}
template<template<typename> class ReductionOp, typename T, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
#ifdef WITH_ROCM
typedef hipcub::BlockReduce<T, block_size> BlockReduce;
#else
typedef cub::BlockReduce<T, block_size> BlockReduce;
#endif
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
......@@ -68,12 +85,20 @@ __inline__ __device__ T Inf();
template<>
__inline__ __device__ float Inf<float>() {
#ifdef WITH_ROCM
return __int_as_float(0x7f800000U);
#else
return CUDART_INF_F;
#endif
}
template<>
__inline__ __device__ double Inf<double>() {
#ifdef WITH_ROCM
return __longlong_as_double(0x7ff0000000000000ULL);
#else
return CUDART_INF;
#endif
}
template<typename T>
......@@ -126,26 +151,26 @@ __inline__ __device__ double Log<double>(double x) {
return log(x);
}
inline cudaError_t GetNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves,
inline GPU(Error_t) GetNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves,
int* num_blocks) {
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err = GPU(GetDevice)(&dev);
if (err != GPU(Success)) { return err; }
}
int sm_count;
{
cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err = GPU(DeviceGetAttribute)(&sm_count, GPUMultiProcessorCount, dev);
if (err != GPU(Success)) { return err; }
}
int tpm;
{
cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err = GPU(DeviceGetAttribute)(&tpm, GPUMaxThreadsPerMultiProcessor, dev);
if (err != GPU(Success)) { return err; }
}
*num_blocks =
std::max<int>(1, std::min<int64_t>(max_blocks, sm_count * tpm / block_size * waves));
return cudaSuccess;
return GPU(Success);
}
template<typename T>
......@@ -272,7 +297,7 @@ __global__ void SoftmaxWarpImpl(LOAD load, STORE store, const int64_t rows, cons
row_buf[i] -= warp_max[row_id];
thread_sum[row_id] += Exp(row_buf[i]);
} else {
__trap();
TRAP();
}
}
}
......@@ -291,7 +316,7 @@ __global__ void SoftmaxWarpImpl(LOAD load, STORE store, const int64_t rows, cons
} else if (algorithm == Algorithm::kLogSoftmax) {
row_buf[i] -= Log(warp_sum[row_id]);
} else {
__trap();
TRAP();
}
}
#pragma unroll
......@@ -307,7 +332,7 @@ __global__ void SoftmaxWarpImpl(LOAD load, STORE store, const int64_t rows, cons
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access, bool padding, Algorithm algorithm>
inline cudaError_t LaunchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store,
inline GPU(Error_t) LaunchSoftmaxWarpImpl(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols) {
constexpr int block_size = 128;
constexpr int waves = 32;
......@@ -318,18 +343,18 @@ inline cudaError_t LaunchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE s
(rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err = GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x);
if (err != GPU(Success)) { return err; }
}
SoftmaxWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread, thread_group_width,
rows_per_access, padding, algorithm>
<<<grid_dim_x, block_dim, 0, stream>>>(load, store, rows, cols);
return cudaPeekAtLastError();
return GPU(PeekAtLastError)();
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access, Algorithm algorithm>
inline cudaError_t DispatchSoftmaxWarpImplPadding(cudaStream_t stream, LOAD load, STORE store,
inline GPU(Error_t) DispatchSoftmaxWarpImplPadding(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols) {
if (cols == cols_per_thread * thread_group_width) {
return LaunchSoftmaxWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread,
......@@ -343,9 +368,9 @@ inline cudaError_t DispatchSoftmaxWarpImplPadding(cudaStream_t stream, LOAD load
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>
typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchSoftmaxWarpImplCols(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) {
if (cols <= 0) { return cudaErrorInvalidValue; }
typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchSoftmaxWarpImplCols(
GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) {
if (cols <= 0) { return GPU(ErrorInvalidValue); }
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
......@@ -403,14 +428,14 @@ typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchSoftmaxWarpIm
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
return GPU(ErrorInvalidValue);
}
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>
typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchSoftmaxWarpImplCols(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) {
if (cols <= 0) { return cudaErrorInvalidValue; }
typename std::enable_if<pack_size == 2, GPU(Error_t)>::type DispatchSoftmaxWarpImplCols(
GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) {
if (cols <= 0) { return GPU(ErrorInvalidValue); }
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
......@@ -452,13 +477,13 @@ typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchSoftmaxWarpIm
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
return GPU(ErrorInvalidValue);
}
}
template<typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
struct DispatchSoftmaxWarpImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
GPU(Error_t) operator()(GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
if (cols % 2 == 0) {
return DispatchSoftmaxWarpImplCols<LOAD, STORE, ComputeType, 2, algorithm>(stream, load,
......@@ -471,7 +496,7 @@ struct DispatchSoftmaxWarpImplPackSize {
};
template<typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
inline cudaError_t DispatchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store,
inline GPU(Error_t) DispatchSoftmaxWarpImpl(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols) {
return DispatchSoftmaxWarpImplPackSize<LOAD, STORE, ComputeType, algorithm>()(stream, load, store,
rows, cols);
......@@ -520,7 +545,7 @@ __global__ void SoftmaxBlockSMemImpl(LOAD load, STORE store, const int64_t rows,
} else if (algorithm == Algorithm::kLogSoftmax) {
pack[i] = buf[i * num_packs + pack_id] - Log(row_sum);
} else {
__trap();
TRAP();
}
}
store.template store<pack_size>(pack, row, pack_id * pack_size);
......@@ -530,21 +555,21 @@ __global__ void SoftmaxBlockSMemImpl(LOAD load, STORE store, const int64_t rows,
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size,
Algorithm algorithm>
inline cudaError_t LaunchSoftmaxBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store, int smem,
inline GPU(Error_t) LaunchSoftmaxBlockSMemImpl(GPU(Stream_t) stream, LOAD load, STORE store, int smem,
const int64_t rows, const int64_t cols) {
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
if (err != GPU(Success)) { return err; }
}
SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size, algorithm>
<<<grid_dim_x, block_size, smem, stream>>>(load, store, rows, cols);
return cudaPeekAtLastError();
return GPU(PeekAtLastError)();
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>
inline cudaError_t TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream, LOAD load,
inline GPU(Error_t) TryDispatchSoftmaxBlockSMemImplBlockSize(GPU(Stream_t) stream, LOAD load,
STORE store, const int64_t rows,
const int64_t cols, bool* success) {
constexpr int block_size_conf_1 = 128;
......@@ -554,23 +579,23 @@ inline cudaError_t TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream,
const size_t smem = cols * sizeof(ComputeType);
int max_active_blocks_conf_1;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks_conf_1,
SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1, algorithm>,
block_size_conf_1, smem);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (max_active_blocks_conf_1 <= 0) {
*success = false;
return cudaSuccess;
return GPU(Success);
}
int max_active_blocks_conf_4;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks_conf_4,
SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4, algorithm>,
block_size_conf_4, smem);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (max_active_blocks_conf_4 == max_active_blocks_conf_1) {
*success = true;
......@@ -579,11 +604,11 @@ inline cudaError_t TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream,
}
int max_active_blocks_conf_3;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks_conf_3,
SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3, algorithm>,
block_size_conf_3, smem);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (max_active_blocks_conf_3 == max_active_blocks_conf_1) {
*success = true;
......@@ -592,11 +617,11 @@ inline cudaError_t TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream,
}
int max_active_blocks_conf_2;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks_conf_2,
SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2, algorithm>,
block_size_conf_2, smem);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {
*success = true;
......@@ -610,7 +635,7 @@ inline cudaError_t TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream,
template<typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
struct TryDispatchSoftmaxBlockSMemImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
GPU(Error_t) operator()(GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, bool* success) {
if (cols % 2 == 0) {
return TryDispatchSoftmaxBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 2, algorithm>(
......@@ -623,7 +648,7 @@ struct TryDispatchSoftmaxBlockSMemImplPackSize {
};
template<typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
inline cudaError_t TryDispatchSoftmaxBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store,
inline GPU(Error_t) TryDispatchSoftmaxBlockSMemImpl(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
bool* success) {
return TryDispatchSoftmaxBlockSMemImplPackSize<LOAD, STORE, ComputeType, algorithm>()(
......@@ -664,7 +689,7 @@ __global__ void SoftmaxBlockUncachedImpl(LOAD load, STORE store, const int64_t r
} else if (algorithm == Algorithm::kLogSoftmax) {
pack[i] = (pack[i] - row_max) - Log(row_sum);
} else {
__trap();
TRAP();
}
}
store.template store<pack_size>(pack, row, pack_id * pack_size);
......@@ -673,23 +698,23 @@ __global__ void SoftmaxBlockUncachedImpl(LOAD load, STORE store, const int64_t r
}
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>
inline cudaError_t LaunchSoftmaxBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,
inline GPU(Error_t) LaunchSoftmaxBlockUncachedImpl(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols) {
constexpr int block_size = 1024;
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
if (err != GPU(Success)) { return err; }
}
SoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size, algorithm>
<<<grid_dim_x, block_size, 0, stream>>>(load, store, rows, cols);
return cudaPeekAtLastError();
return GPU(PeekAtLastError)();
}
template<typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
struct DispatchSoftmaxBlockUncachedImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
GPU(Error_t) operator()(GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
if (cols % 2 == 0) {
return LaunchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, 2, algorithm>(
......@@ -702,15 +727,15 @@ struct DispatchSoftmaxBlockUncachedImplPackSize {
};
template<typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
inline cudaError_t DispatchSoftmaxBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,
inline GPU(Error_t) DispatchSoftmaxBlockUncachedImpl(GPU(Stream_t) stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols) {
return DispatchSoftmaxBlockUncachedImplPackSize<LOAD, STORE, ComputeType, algorithm>()(
stream, load, store, rows, cols);
}
template<typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
DispatchSoftmax(GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
if (cols < 1024) {
return DispatchSoftmaxWarpImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
......@@ -718,30 +743,30 @@ DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err =
GPU(Error_t) err =
TryDispatchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
stream, load, store, rows, cols, &dispatch_smem_impl_success);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (!dispatch_smem_impl_success) {
return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
stream, load, store, rows, cols);
}
return cudaSuccess;
return GPU(Success);
}
}
template<typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
inline typename std::enable_if<std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
DispatchSoftmax(GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
stream, load, store, rows, cols);
}
template<typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLogSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
DispatchLogSoftmax(GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
if (cols <= 1024) {
return DispatchSoftmaxWarpImpl<LOAD, STORE, ComputeType, Algorithm::kLogSoftmax>(
......@@ -749,22 +774,22 @@ DispatchLogSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t ro
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err =
GPU(Error_t) err =
TryDispatchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, Algorithm::kLogSoftmax>(
stream, load, store, rows, cols, &dispatch_smem_impl_success);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (!dispatch_smem_impl_success) {
return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, Algorithm::kLogSoftmax>(
stream, load, store, rows, cols);
}
return cudaSuccess;
return GPU(Success);
}
}
template<typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLogSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
inline typename std::enable_if<std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
DispatchLogSoftmax(GPU(Stream_t) stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, Algorithm::kLogSoftmax>(
stream, load, store, rows, cols);
......@@ -807,7 +832,7 @@ __global__ void SoftmaxGradWarpImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store,
} else if (algorithm == Algorithm::kLogSoftmax) {
thread_sum[row_id] += row_dy_buf[pack_offset + i];
} else {
__trap();
TRAP();
}
}
}
......@@ -834,7 +859,7 @@ __global__ void SoftmaxGradWarpImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store,
} else if (algorithm == Algorithm::kLogSoftmax) {
row_dy_buf[pack_offset + i] -= Exp(row_y_buf[pack_offset + i]) * warp_sum[row_id];
} else {
__trap();
TRAP();
}
}
store.template store<pack_size>(row_dy_buf + pack_offset, row + row_id, col);
......@@ -847,7 +872,7 @@ __global__ void SoftmaxGradWarpImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store,
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int cols_per_thread, int thread_group_width, int rows_per_access, bool padding,
Algorithm algorithm>
inline cudaError_t LaunchSoftmaxGradWarpImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy,
inline GPU(Error_t) LaunchSoftmaxGradWarpImpl(GPU(Stream_t) stream, LOAD_Y load_y, LOAD_DY load_dy,
STORE store, const int64_t rows, const int64_t cols) {
constexpr int block_size = 128;
constexpr int waves = 32;
......@@ -858,18 +883,18 @@ inline cudaError_t LaunchSoftmaxGradWarpImpl(cudaStream_t stream, LOAD_Y load_y,
(rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err = GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x);
if (err != GPU(Success)) { return err; }
}
SoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, padding, algorithm>
<<<grid_dim_x, block_dim, 0, stream>>>(load_y, load_dy, store, rows, cols);
return cudaPeekAtLastError();
return GPU(PeekAtLastError)();
}
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int cols_per_thread, int thread_group_width, int rows_per_access, Algorithm algorithm>
inline cudaError_t DispatchSoftmaxGradWarpImplPadding(cudaStream_t stream, LOAD_Y load_y,
inline GPU(Error_t) DispatchSoftmaxGradWarpImplPadding(GPU(Stream_t) stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
if (cols == cols_per_thread * thread_group_width) {
......@@ -885,10 +910,10 @@ inline cudaError_t DispatchSoftmaxGradWarpImplPadding(cudaStream_t stream, LOAD_
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
Algorithm algorithm>
typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchSoftmaxGradWarpImplCols(
cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows,
typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchSoftmaxGradWarpImplCols(
GPU(Stream_t) stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows,
const int64_t cols) {
if (cols <= 0) { return cudaErrorInvalidValue; }
if (cols <= 0) { return GPU(ErrorInvalidValue); }
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
......@@ -947,16 +972,16 @@ typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchSoftmaxGradWa
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
return GPU(ErrorInvalidValue);
}
}
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
Algorithm algorithm>
typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchSoftmaxGradWarpImplCols(
cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows,
typename std::enable_if<pack_size == 2, GPU(Error_t)>::type DispatchSoftmaxGradWarpImplCols(
GPU(Stream_t) stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows,
const int64_t cols) {
if (cols <= 0) { return cudaErrorInvalidValue; }
if (cols <= 0) { return GPU(ErrorInvalidValue); }
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
......@@ -999,14 +1024,14 @@ typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchSoftmaxGradWa
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
return GPU(ErrorInvalidValue);
}
}
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
struct DispatchSoftmaxGradWarpImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
GPU(Error_t) operator()(GPU(Stream_t) stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
if (cols % 2 == 0) {
return DispatchSoftmaxGradWarpImplCols<LOAD_Y, LOAD_DY, STORE, ComputeType, 2, algorithm>(
......@@ -1020,7 +1045,7 @@ struct DispatchSoftmaxGradWarpImplPackSize {
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
inline cudaError_t DispatchSoftmaxGradWarpImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy,
inline GPU(Error_t) DispatchSoftmaxGradWarpImpl(GPU(Stream_t) stream, LOAD_Y load_y, LOAD_DY load_dy,
STORE store, const int64_t rows,
const int64_t cols) {
return DispatchSoftmaxGradWarpImplPackSize<LOAD_Y, LOAD_DY, STORE, ComputeType, algorithm>()(
......@@ -1053,7 +1078,7 @@ __global__ void SoftmaxGradBlockSMemImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE s
} else if (algorithm == Algorithm::kLogSoftmax) {
thread_sum += dy_pack[i];
} else {
__trap();
TRAP();
}
}
}
......@@ -1067,7 +1092,7 @@ __global__ void SoftmaxGradBlockSMemImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE s
} else if (algorithm == Algorithm::kLogSoftmax) {
pack[i] = dy_buf[i * num_packs + pack_id] - Exp(y_buf[i * num_packs + pack_id]) * row_sum;
} else {
__trap();
TRAP();
}
}
store.template store<pack_size>(pack, row, pack_id * pack_size);
......@@ -1077,23 +1102,23 @@ __global__ void SoftmaxGradBlockSMemImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE s
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int block_size, Algorithm algorithm>
inline cudaError_t LaunchSoftmaxGradBlockSMemImpl(cudaStream_t stream, LOAD_Y load_y,
inline GPU(Error_t) LaunchSoftmaxGradBlockSMemImpl(GPU(Stream_t) stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store, int smem,
const int64_t rows, const int64_t cols) {
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
if (err != GPU(Success)) { return err; }
}
SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, block_size, algorithm>
<<<grid_dim_x, block_size, smem, stream>>>(load_y, load_dy, store, rows, cols);
return cudaPeekAtLastError();
return GPU(PeekAtLastError)();
}
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
Algorithm algorithm>
inline cudaError_t TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t stream, LOAD_Y load_y,
inline GPU(Error_t) TryDispatchSoftmaxGradBlockSMemImplBlockSize(GPU(Stream_t) stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store,
const int64_t rows,
const int64_t cols, bool* success) {
......@@ -1104,25 +1129,25 @@ inline cudaError_t TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t str
const size_t smem = cols * sizeof(ComputeType) * 2;
int max_active_blocks_conf_1;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks_conf_1,
SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, block_size_conf_1,
algorithm>,
block_size_conf_1, smem);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (max_active_blocks_conf_1 <= 0) {
*success = false;
return cudaSuccess;
return GPU(Success);
}
int max_active_blocks_conf_4;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks_conf_4,
SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, block_size_conf_4,
algorithm>,
block_size_conf_4, smem);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (max_active_blocks_conf_4 == max_active_blocks_conf_1) {
*success = true;
......@@ -1132,12 +1157,12 @@ inline cudaError_t TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t str
}
int max_active_blocks_conf_3;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks_conf_3,
SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, block_size_conf_3,
algorithm>,
block_size_conf_3, smem);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (max_active_blocks_conf_3 == max_active_blocks_conf_1) {
*success = true;
......@@ -1147,12 +1172,12 @@ inline cudaError_t TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t str
}
int max_active_blocks_conf_2;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
GPU(Error_t) err = GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
&max_active_blocks_conf_2,
SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, block_size_conf_2,
algorithm>,
block_size_conf_2, smem);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {
*success = true;
......@@ -1169,7 +1194,7 @@ inline cudaError_t TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t str
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
struct TryDispatchSoftmaxGradBlockSMemImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
GPU(Error_t) operator()(GPU(Stream_t) stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols, bool* success) {
if (cols % 2 == 0) {
return TryDispatchSoftmaxGradBlockSMemImplBlockSize<LOAD_Y, LOAD_DY, STORE, ComputeType, 2,
......@@ -1185,7 +1210,7 @@ struct TryDispatchSoftmaxGradBlockSMemImplPackSize {
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
inline cudaError_t TryDispatchSoftmaxGradBlockSMemImpl(cudaStream_t stream, LOAD_Y load_y,
inline GPU(Error_t) TryDispatchSoftmaxGradBlockSMemImpl(GPU(Stream_t) stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols,
bool* success) {
......@@ -1216,7 +1241,7 @@ __global__ void SoftmaxGradBlockUncachedImpl(LOAD_Y load_y, LOAD_DY load_dy, STO
} else if (algorithm == Algorithm::kLogSoftmax) {
thread_sum += dy_pack[i];
} else {
__trap();
TRAP();
}
}
}
......@@ -1233,7 +1258,7 @@ __global__ void SoftmaxGradBlockUncachedImpl(LOAD_Y load_y, LOAD_DY load_dy, STO
} else if (algorithm == Algorithm::kLogSoftmax) {
dy_pack[i] -= Exp(y_pack[i]) * row_sum;
} else {
__trap();
TRAP();
}
}
store.template store<pack_size>(dy_pack, row, pack_id * pack_size);
......@@ -1243,26 +1268,26 @@ __global__ void SoftmaxGradBlockUncachedImpl(LOAD_Y load_y, LOAD_DY load_dy, STO
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
Algorithm algorithm>
inline cudaError_t LaunchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, LOAD_Y load_y,
inline GPU(Error_t) LaunchSoftmaxGradBlockUncachedImpl(GPU(Stream_t) stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
constexpr int block_size = 1024;
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
if (err != cudaSuccess) { return err; }
GPU(Error_t) err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
if (err != GPU(Success)) { return err; }
}
SoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, block_size,
algorithm>
<<<grid_dim_x, block_size, 0, stream>>>(load_y, load_dy, store, rows, cols);
return cudaPeekAtLastError();
return GPU(PeekAtLastError)();
}
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
struct DispatchSoftmaxGradBlockUncachedImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
GPU(Error_t) operator()(GPU(Stream_t) stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
if (cols % 2 == 0 && cols > kWarpSize) {
return LaunchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, 2, algorithm>(
......@@ -1276,7 +1301,7 @@ struct DispatchSoftmaxGradBlockUncachedImplPackSize {
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
inline cudaError_t DispatchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, LOAD_Y load_y,
inline GPU(Error_t) DispatchSoftmaxGradBlockUncachedImpl(GPU(Stream_t) stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
return DispatchSoftmaxGradBlockUncachedImplPackSize<LOAD_Y, LOAD_DY, STORE, ComputeType,
......@@ -1285,8 +1310,8 @@ inline cudaError_t DispatchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, LOA
}
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
DispatchSoftmaxGrad(GPU(Stream_t) stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
if (cols <= 1024) {
return DispatchSoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, Algorithm::kSoftmax>(
......@@ -1294,23 +1319,23 @@ DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE s
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err = TryDispatchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
GPU(Error_t) err = TryDispatchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kSoftmax>(
stream, load_y, load_dy, store, rows, cols, &dispatch_smem_impl_success);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (!dispatch_smem_impl_success) {
return DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kSoftmax>(stream, load_y, load_dy,
store, rows, cols);
}
return cudaSuccess;
return GPU(Success);
}
}
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
inline typename std::enable_if<std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
DispatchSoftmaxGrad(GPU(Stream_t) stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
return DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kSoftmax>(stream, load_y, load_dy, store,
......@@ -1318,8 +1343,8 @@ DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE s
}
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLogSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
DispatchLogSoftmaxGrad(GPU(Stream_t) stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
if (cols <= 1024) {
return DispatchSoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, Algorithm::kLogSoftmax>(
......@@ -1327,23 +1352,23 @@ DispatchLogSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STOR
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err = TryDispatchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
GPU(Error_t) err = TryDispatchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kLogSoftmax>(
stream, load_y, load_dy, store, rows, cols, &dispatch_smem_impl_success);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (!dispatch_smem_impl_success) {
return DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kLogSoftmax>(stream, load_y, load_dy,
store, rows, cols);
}
return cudaSuccess;
return GPU(Success);
}
}
template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLogSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
inline typename std::enable_if<std::is_same<ComputeType, double>::value, GPU(Error_t)>::type
DispatchLogSoftmaxGrad(GPU(Stream_t) stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
return DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kLogSoftmax>(stream, load_y, load_dy,
......
......@@ -16,8 +16,14 @@ limitations under the License.
#ifndef ONEFLOW_CORE_CUDA_UNIQUE_H_
#define ONEFLOW_CORE_CUDA_UNIQUE_H_
#ifdef WITH_ROCM
#include "hip/hip_runtime.h"
#include <hipcub/hipcub.hpp>
#else
#include <cub/cub.cuh>
#include <device_launch_parameters.h>
#endif
#include "oneflow/core/common/permutation_iterator.h"
#include "oneflow/core/common/not_equal_to_previous_adjacent_iterator.h"
......@@ -49,82 +55,98 @@ __device__ __host__ __forceinline__ T* PtrOffset(void* ptr, size_t offset) {
__device__ __host__ __forceinline__ size_t max(size_t a, size_t b) { return a > b ? a : b; }
template<typename Key, typename Index>
cudaError_t DoUnique(size_t n, const Key* sorted_in, Key* unique, Index* num_unique,
void* workspace, size_t* workspace_size, cudaStream_t stream) {
GPU(Error_t) DoUnique(size_t n, const Key* sorted_in, Key* unique, Index* num_unique,
void* workspace, size_t* workspace_size, GPU(Stream_t) stream) {
size_t ws = *workspace_size;
cudaError_t err = cub::DeviceSelect::Unique<const Key*, Key*, Index*>(
#ifdef WITH_ROCM
GPU(Error_t) err = hipcub::DeviceSelect::Unique<const Key*, Key*, Index*>(
workspace, ws, sorted_in, unique, num_unique, n, stream);
#else
GPU(Error_t) err = cub::DeviceSelect::Unique<const Key*, Key*, Index*>(
workspace, ws, sorted_in, unique, num_unique, n, stream);
if (err != cudaSuccess) { return err; }
#endif
if (err != GPU(Success)) { return err; }
if (*workspace_size == 0) { *workspace_size = ws; }
return cudaSuccess;
return GPU(Success);
}
template<typename Key, typename Index>
cudaError_t DoUniqueWithCounts(size_t n, const Key* sorted_in, Key* unique, Index* num_unique,
GPU(Error_t) DoUniqueWithCounts(size_t n, const Key* sorted_in, Key* unique, Index* num_unique,
Index* counts, void* workspace, size_t* workspace_size,
cudaStream_t stream) {
GPU(Stream_t) stream) {
size_t ws = *workspace_size;
cudaError_t err = cub::DeviceRunLengthEncode::Encode<const Key*, Key*, Index*, Index*>(
#ifdef WITH_ROCM
GPU(Error_t) err = hipcub::DeviceRunLengthEncode::Encode<const Key*, Key*, Index*, Index*>(
workspace, ws, sorted_in, unique, counts, num_unique, n, stream);
#else
GPU(Error_t) err = cub::DeviceRunLengthEncode::Encode<const Key*, Key*, Index*, Index*>(
workspace, ws, sorted_in, unique, counts, num_unique, n, stream);
if (err != cudaSuccess) { return err; }
#endif
if (err != GPU(Success)) { return err; }
if (*workspace_size == 0) { *workspace_size = ws; }
return cudaSuccess;
return GPU(Success);
}
template<typename Key, typename Index>
cudaError_t DispatchOutputCounts(Flag flag, size_t n, const Key* sorted_in, Key* unique,
GPU(Error_t) DispatchOutputCounts(Flag flag, size_t n, const Key* sorted_in, Key* unique,
Index* num_unique, Index* counts, void* workspace,
size_t* workspace_size, cudaStream_t stream) {
size_t* workspace_size, GPU(Stream_t) stream) {
size_t ws = *workspace_size;
if ((flag & kOutputCounts) != 0) {
cudaError_t err = DoUniqueWithCounts<Key, Index>(n, sorted_in, unique, num_unique, counts,
GPU(Error_t) err = DoUniqueWithCounts<Key, Index>(n, sorted_in, unique, num_unique, counts,
workspace, &ws, stream);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
} else {
cudaError_t err =
GPU(Error_t) err =
DoUnique<Key, Index>(n, sorted_in, unique, num_unique, workspace, &ws, stream);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (*workspace_size == 0) { *workspace_size = ws; }
return cudaSuccess;
return GPU(Success);
}
template<typename Key, typename Index, typename InverseIndicesIter>
cudaError_t DoGenInverseIndices(size_t n, const Key* sorted_in,
GPU(Error_t) DoGenInverseIndices(size_t n, const Key* sorted_in,
InverseIndicesIter inverse_indices_iter, void* workspace,
size_t* workspace_size, cudaStream_t stream) {
size_t* workspace_size, GPU(Stream_t) stream) {
size_t ws = *workspace_size;
NotEqualToPreviousAdjacentIterator<Index, Key> unique_counting_iter(sorted_in, 0);
cudaError_t err =
#ifdef WITH_ROCM
GPU(Error_t) err =
hipcub::DeviceScan::InclusiveSum<decltype(unique_counting_iter), InverseIndicesIter>(
workspace, ws, unique_counting_iter, inverse_indices_iter, n, stream);
#else
GPU(Error_t) err =
cub::DeviceScan::InclusiveSum<decltype(unique_counting_iter), InverseIndicesIter>(
workspace, ws, unique_counting_iter, inverse_indices_iter, n, stream);
if (err != cudaSuccess) { return err; }
#endif
if (err != GPU(Success)) { return err; }
if (*workspace_size == 0) { *workspace_size = ws; }
return cudaSuccess;
return GPU(Success);
}
template<typename Key, typename Index, typename InverseIndicesIter>
cudaError_t DispatchOutputInverseIndices(Flag flag, size_t n, const Key* sorted_in, Key* unique,
GPU(Error_t) DispatchOutputInverseIndices(Flag flag, size_t n, const Key* sorted_in, Key* unique,
Index* num_unique, InverseIndicesIter inverse_indices_iter,
Index* counts, void* workspace, size_t* workspace_size,
cudaStream_t stream) {
GPU(Stream_t) stream) {
size_t dispatch_with_counts_ws = *workspace_size;
size_t do_gen_inverse_indices_ws = *workspace_size;
{
cudaError_t err =
GPU(Error_t) err =
DispatchOutputCounts<Key, Index>(flag, n, sorted_in, unique, num_unique, counts, workspace,
&dispatch_with_counts_ws, stream);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if ((flag & kOutputInverseIndices) != 0) {
cudaError_t err = DoGenInverseIndices<Key, Index, InverseIndicesIter>(
GPU(Error_t) err = DoGenInverseIndices<Key, Index, InverseIndicesIter>(
n, sorted_in, inverse_indices_iter, workspace, &do_gen_inverse_indices_ws, stream);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (*workspace_size == 0) {
*workspace_size = max(dispatch_with_counts_ws, do_gen_inverse_indices_ws);
}
return cudaSuccess;
return GPU(Success);
}
template<typename T>
......@@ -136,8 +158,8 @@ __global__ void IotaKernel(size_t n, T* out) {
}
template<typename Key, typename Index>
cudaError_t DoSort(size_t n, const Key* in, Key* sorted, Index* sorted_indices, void* workspace,
size_t* workspace_size, cudaStream_t stream) {
GPU(Error_t) DoSort(size_t n, const Key* in, Key* sorted, Index* sorted_indices, void* workspace,
size_t* workspace_size, GPU(Stream_t) stream) {
Index* indices;
const size_t indices_size = GetCudaAlignedSize(n * sizeof(Index));
void* sort_workspace;
......@@ -147,7 +169,7 @@ cudaError_t DoSort(size_t n, const Key* in, Key* sorted, Index* sorted_indices,
sort_workspace = nullptr;
sort_ws = 0;
} else {
if (*workspace_size <= indices_size) { return cudaErrorInvalidValue; }
if (*workspace_size <= indices_size) { return GPU(ErrorInvalidValue); }
indices = PtrOffset<Index>(workspace, 0);
sort_workspace = PtrOffset<Index>(workspace, indices_size);
sort_ws = *workspace_size - indices_size;
......@@ -157,17 +179,22 @@ cudaError_t DoSort(size_t n, const Key* in, Key* sorted, Index* sorted_indices,
const int num_blocks = static_cast<int>((n + block_size - 1) / block_size);
IotaKernel<Index><<<num_blocks, block_size, 0, stream>>>(n, indices);
}
cudaError_t err = cub::DeviceRadixSort::SortPairs<Key, Index>(
#ifdef WITH_ROCM
GPU(Error_t) err = hipcub::DeviceRadixSort::SortPairs<Key, Index>(
sort_workspace, sort_ws, in, sorted, indices, sorted_indices, n, 0, sizeof(Key) * 8, stream);
#else
GPU(Error_t) err = cub::DeviceRadixSort::SortPairs<Key, Index>(
sort_workspace, sort_ws, in, sorted, indices, sorted_indices, n, 0, sizeof(Key) * 8, stream);
if (err != cudaSuccess) { return err; }
#endif
if (err != GPU(Success)) { return err; }
if (*workspace_size == 0) { *workspace_size = indices_size + sort_ws; }
return cudaSuccess;
return GPU(Success);
}
template<typename Key, typename Index>
cudaError_t DispatchInputSorted(Flag flag, size_t n, const Key* in, Key* unique, Index* num_unique,
GPU(Error_t) DispatchInputSorted(Flag flag, size_t n, const Key* in, Key* unique, Index* num_unique,
Index* inverse_indices, Index* counts, void* workspace,
size_t* workspace_size, cudaStream_t stream) {
size_t* workspace_size, GPU(Stream_t) stream) {
if ((flag & kInputSorted) != 0) {
return DispatchOutputInverseIndices<Key, Index, Index*>(flag, n, in, unique, num_unique,
inverse_indices, counts, workspace,
......@@ -190,7 +217,7 @@ cudaError_t DispatchInputSorted(Flag flag, size_t n, const Key* in, Key* unique,
do_inverse_indices_ws = 0;
do_inverse_indices_workspace = nullptr;
} else {
if (*workspace_size <= sort_buffer_size) { return cudaErrorInvalidValue; }
if (*workspace_size <= sort_buffer_size) { return GPU(ErrorInvalidValue); }
sorted_in = PtrOffset<Key>(workspace, 0);
sorted_indices = PtrOffset<Index>(workspace, sorted_in_size);
do_sort_ws = *workspace_size - sort_buffer_size;
......@@ -199,38 +226,38 @@ cudaError_t DispatchInputSorted(Flag flag, size_t n, const Key* in, Key* unique,
do_inverse_indices_workspace = do_sort_workspace;
}
{
cudaError_t err = DoSort<Key, Index>(n, in, sorted_in, sorted_indices, do_sort_workspace,
GPU(Error_t) err = DoSort<Key, Index>(n, in, sorted_in, sorted_indices, do_sort_workspace,
&do_sort_ws, stream);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
PermutationIterator<Index, Index*, Index*> inverse_indices_iter(inverse_indices,
sorted_indices);
{
cudaError_t err = DispatchOutputInverseIndices<Key, Index, decltype(inverse_indices_iter)>(
GPU(Error_t) err = DispatchOutputInverseIndices<Key, Index, decltype(inverse_indices_iter)>(
flag, n, sorted_in, unique, num_unique, inverse_indices_iter, counts,
do_inverse_indices_workspace, &do_inverse_indices_ws, stream);
if (err != cudaSuccess) { return err; }
if (err != GPU(Success)) { return err; }
}
if (*workspace_size == 0) {
*workspace_size = sort_buffer_size + max(do_sort_ws, do_inverse_indices_ws);
}
return cudaSuccess;
return GPU(Success);
}
}
} // namespace
template<typename Key, typename Index>
cudaError_t Launch(Flag flag, size_t n, const Key* in, Key* unique, Index* num_unique,
GPU(Error_t) Launch(Flag flag, size_t n, const Key* in, Key* unique, Index* num_unique,
Index* inverse_indices, Index* counts, void* workspace, size_t workspace_size,
cudaStream_t stream) {
if (workspace_size == 0) { return cudaErrorInvalidValue; }
GPU(Stream_t) stream) {
if (workspace_size == 0) { return GPU(ErrorInvalidValue); }
return DispatchInputSorted<Key, Index>(flag, n, in, unique, num_unique, inverse_indices, counts,
workspace, &workspace_size, stream);
}
template<typename Key, typename Index>
cudaError_t GetWorkspaceSize(Flag flag, size_t n, size_t* workspace_size) {
GPU(Error_t) GetWorkspaceSize(Flag flag, size_t n, size_t* workspace_size) {
*workspace_size = 0;
return DispatchInputSorted<Key, Index>(flag, n, nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, workspace_size, 0);
......
......@@ -23,11 +23,7 @@ limitations under the License.
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/platform/include/pthread_fork.h"
#include "oneflow/core/device/device_context.h"
#ifdef WITH_ROCM
#include "oneflow/core/ep/rocm/cuda_stream.h"
#else
#include "oneflow/core/ep/cuda/cuda_stream.h"
#endif
#include "oneflow/core/vm/vm_util.h"
#ifdef WITH_CUDA
......@@ -193,6 +189,10 @@ Maybe<double> GetCUDAMemoryUsed() {
int deviceCount = 0;
cudaError_t error_id = cudaGetDeviceCount(&deviceCount);
if (error_id != cudaSuccess) {
return Error::RuntimeError() << "Error: GetCUDAMemoryUsed fails :"
<< cudaGetErrorString(error_id);
}
CHECK_OR_RETURN(deviceCount > 0) << "GPU device does not exist";
......@@ -209,6 +209,26 @@ Maybe<double> GetCUDAMemoryUsed() {
return (total_memory - free_memory);
}
static std::once_flag prop_init_flag;
static std::vector<cudaDeviceProp> device_props;
void InitDevicePropVectorSize() {
int device_count = GetCudaDeviceCount();
device_props.resize(device_count);
}
void InitDeviceProperties(int device_id) {
std::call_once(prop_init_flag, InitDevicePropVectorSize);
cudaDeviceProp prop{};
OF_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id));
device_props[device_id] = prop;
}
cudaDeviceProp* GetDeviceProperties(int device_id) {
InitCudaContextOnce(device_id);
return &device_props[device_id];
}
void InitCudaContextOnce(int device_id) {
static int device_count = GetCudaDeviceCount();
static std::vector<std::once_flag> init_flags = std::vector<std::once_flag>(device_count);
......@@ -217,6 +237,7 @@ void InitCudaContextOnce(int device_id) {
std::call_once(init_flags[device_id], [&]() {
OF_CUDA_CHECK(cudaSetDevice(device_id));
OF_CUDA_CHECK(cudaDeviceSynchronize());
InitDeviceProperties(device_id);
});
}
......@@ -361,6 +382,10 @@ Maybe<double> GetCUDAMemoryUsed() {
int deviceCount = 0;
hipError_t error_id = hipGetDeviceCount(&deviceCount);
if (error_id != hipSuccess) {
return Error::RuntimeError() << "Error: GetCUDAMemoryUsed fails :"
<< hipGetErrorString(error_id);
}
CHECK_OR_RETURN(deviceCount > 0) << "GPU device does not exist";
......@@ -377,6 +402,26 @@ Maybe<double> GetCUDAMemoryUsed() {
return (total_memory - free_memory);
}
static std::once_flag prop_init_flag;
static std::vector<hipDeviceProp_t> device_props;
void InitDevicePropVectorSize() {
int device_count = GetCudaDeviceCount();
device_props.resize(device_count);
}
void InitDeviceProperties(int device_id) {
std::call_once(prop_init_flag, InitDevicePropVectorSize);
hipDeviceProp_t prop{};
OF_CUDA_CHECK(hipGetDeviceProperties(&prop, device_id));
device_props[device_id] = prop;
}
hipDeviceProp_t* GetDeviceProperties(int device_id) {
InitCudaContextOnce(device_id);
return &device_props[device_id];
}
void InitCudaContextOnce(int device_id) {
static int device_count = GetCudaDeviceCount();
static std::vector<std::once_flag> init_flags = std::vector<std::once_flag>(device_count);
......@@ -385,11 +430,10 @@ void InitCudaContextOnce(int device_id) {
std::call_once(init_flags[device_id], [&]() {
OF_CUDA_CHECK(hipSetDevice(device_id));
OF_CUDA_CHECK(hipDeviceSynchronize());
InitDeviceProperties(device_id);
});
}
#endif // WITH_ROCM
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <cub/cub.cuh>
#include "oneflow/core/device/cuda_util.h"
namespace oneflow {
int GetCudaSmVersion() {
int sm_version, device_ordinal;
OF_CUDA_CHECK(cudaGetDevice(&device_ordinal));
OF_CUDA_CHECK(cub::SmVersion(sm_version, device_ordinal));
return sm_version;
}
int GetCudaPtxVersion() {
int ptx_version;
OF_CUDA_CHECK(cub::PtxVersion(ptx_version));
return ptx_version;
}
} // namespace oneflow
......@@ -30,6 +30,9 @@ limitations under the License.
#include <curand.h>
#include <nccl.h>
#include <cuda_fp16.h>
#if CUDA_VERSION >= 11000
#include <cuda_bf16.h>
#endif // CUDA_VERSION >= 11000
#include "oneflow/core/device/cuda_pseudo_half.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
......@@ -82,7 +85,10 @@ const char* NvjpegGetErrorString(nvjpegStatus_t error);
#define OF_NCCL_CHECK_OR_RETURN(condition) \
for (ncclResult_t _of_nccl_check_status = (condition); _of_nccl_check_status != ncclSuccess;) \
return Error::CheckFailedError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) \
return Error::CheckFailedError().AddStackFrame([](const char* function) { \
thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \
return frame; \
}(__FUNCTION__)) \
<< "Check failed: " #condition " : " << ncclGetErrorString(_of_nccl_check_status) << " (" \
<< _of_nccl_check_status << ") "
......@@ -152,16 +158,14 @@ class CublasMathModeGuard final {
cublasMath_t new_mode_{};
};
int GetCudaSmVersion();
int GetCudaPtxVersion();
int GetCudaDeviceIndex();
int GetCudaDeviceCount();
Maybe<double> GetCUDAMemoryUsed();
cudaDeviceProp* GetDeviceProperties(int device_id);
void SetCudaDeviceIndex(int device_id);
void CudaSynchronize(int device_id);
......@@ -184,7 +188,11 @@ cudaError_t CudaDriverGetPrimaryCtxActive(int dev, int* active);
#include <rccl.h>
#include <hip/hip_fp16.h>
#include "oneflow/core/device/cuda_pseudo_half.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
// #if CUDA_VERSION >= 11000
// #include <cuda_bf16.h>
// #endif // CUDA_VERSION >= 11000
namespace oneflow {
......@@ -223,7 +231,10 @@ const char* CurandGetErrorString(hiprandStatus_t error);
#define OF_NCCL_CHECK_OR_RETURN(condition) \
for (ncclResult_t _of_nccl_check_status = (condition); _of_nccl_check_status != ncclSuccess;) \
return Error::CheckFailedError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) \
return Error::CheckFailedError().AddStackFrame([](const char* function) { \
thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \
return frame; \
}(__FUNCTION__)) \
<< "Check failed: " #condition " : " << ncclGetErrorString(_of_nccl_check_status) << " (" \
<< _of_nccl_check_status << ") "
......@@ -275,6 +286,8 @@ int GetCudaDeviceCount();
Maybe<double> GetCUDAMemoryUsed();
hipDeviceProp_t* GetDeviceProperties(int device_id);
void SetCudaDeviceIndex(int device_id);
void CudaSynchronize(int device_id);
......
......@@ -341,7 +341,10 @@ ManagedCudnnConvResource::ManagedCudnnConvResource(const CudnnConvArgs& args)
}
ManagedCudnnConvResource::~ManagedCudnnConvResource() {
if (handle_ != nullptr) { OF_CUDNN_CHECK(cudnnDestroy(handle_)); }
if (handle_ != nullptr) {
Singleton<CudnnHandlePool>::Get()->Put(handle_);
handle_ = nullptr;
}
if (x_dptr_ != nullptr) { OF_CUDA_CHECK(cudaFree(x_dptr_)); }
if (w_dptr_ != nullptr) { OF_CUDA_CHECK(cudaFree(w_dptr_)); }
if (y_dptr_ != nullptr) { OF_CUDA_CHECK(cudaFree(y_dptr_)); }
......@@ -349,7 +352,7 @@ ManagedCudnnConvResource::~ManagedCudnnConvResource() {
}
cudnnHandle_t ManagedCudnnConvResource::cudnn_handle() {
if (handle_ == nullptr) { OF_CUDNN_CHECK(cudnnCreate(&handle_)); }
if (handle_ == nullptr) { handle_ = Singleton<CudnnHandlePool>::Get()->Get(); }
return handle_;
}
......@@ -392,7 +395,12 @@ bool operator==(const CudnnConvParams& a, const CudnnConvParams& b) {
}
DataType GetConvDescDataType(DataType data_type, bool pseudo_half) {
return (data_type == DataType::kFloat16 && pseudo_half) ? DataType::kFloat : data_type;
if (data_type == DataType::kFloat16 && pseudo_half) {
return DataType::kFloat;
} else if (data_type == DataType::kBFloat16) {
return DataType::kFloat;
}
return data_type;
}
cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res,
......@@ -669,25 +677,6 @@ perf_t GetBestAlgorithm(const CudnnConvArgs& args, CudnnConvResource* res,
<< ") requires memory " << perf_vec[0].memory;
}
// #if HIPDNN_VERSION < 7500
// // google [blacklist fft algorithms for strided dgrad]
// if (std::is_same<decltype(perf_vec[found_algo_idx].algo), hipdnnConvolutionBwdDataAlgo_t>::value) {
// int stride_dim = args.params.x_ndim - 2;
// bool blacklist =
// std::any_of(std::begin(args.params.stride), std::begin(args.params.stride) + stride_dim,
// [](int n) { return n != 1; });
// if (blacklist
// && (static_cast<hipdnnConvolutionBwdDataAlgo_t>(perf_vec[found_algo_idx].algo)
// == HIPDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING
// || static_cast<hipdnnConvolutionBwdDataAlgo_t>(perf_vec[found_algo_idx].algo)
// == HIPDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
// perf_t algo_perf;
// SetAlgo4Perf(args, res, &algo_perf, GetDefaultAlgo<algo_t>());
// return algo_perf;
// }
// }
// #endif
return perf_vec.at(found_algo_idx);
}
......@@ -864,39 +853,37 @@ CudnnConvArgs::CudnnConvArgs(const user_op::InferContext& ctx, DataType x_data_t
wdesc(w_data_type, w_shape, data_format),
cdesc(GetConvDescDataType(x_data_type, enable_pseudo_half), x_data_type, x_shape, ctx),
heuristic(heuristic_search),
deterministic(use_deterministic_algo_only) {
std::memset(&params, 0, sizeof(CudnnConvParams));
OF_CUDNN_CHECK(hipdnnGetTensorNdDescriptor(xdesc.Get(), CudnnConvParams::kTensorMaxDims,
&params.x_data_type, &params.x_ndim, params.x_dims,
params.x_strides));
OF_CUDNN_CHECK(hipdnnGetTensorNdDescriptor(ydesc.Get(), CudnnConvParams::kTensorMaxDims,
&params.y_data_type, &params.y_ndim, params.y_dims,
params.y_strides));
OF_CUDNN_CHECK(hipdnnGetFilterNdDescriptor(wdesc.Get(), CudnnConvParams::kTensorMaxDims,
&params.w_data_type, &params.w_format, &params.w_ndim,
params.w_dims));
hipdnnConvolutionMode_t mode;
int conv_dim_size = x_shape.NumAxes() - 2;
for (int i=0; i<3; i++) {
params.padding[i] = cdesc.CD_padding[i];
params.stride[i] = cdesc.CD_stride[i];
params.dilation[i] = cdesc.CD_dilation[i];
}
deterministic(use_deterministic_algo_only),
max_ws_size(max_workspace_size) {
// std::memset(&params, 0, sizeof(CudnnConvParams));
// OF_CUDNN_CHECK(hipdnnGetTensorNdDescriptor(xdesc.Get(), CudnnConvParams::kTensorMaxDims,
// &params.x_data_type, &params.x_ndim, params.x_dims,
// params.x_strides));
// OF_CUDNN_CHECK(hipdnnGetTensorNdDescriptor(ydesc.Get(), CudnnConvParams::kTensorMaxDims,
// &params.y_data_type, &params.y_ndim, params.y_dims,
// params.y_strides));
// OF_CUDNN_CHECK(hipdnnGetFilterNdDescriptor(wdesc.Get(), CudnnConvParams::kTensorMaxDims,
// &params.w_data_type, &params.w_format, &params.w_ndim,
// params.w_dims));
// hipdnnConvolutionMode_t mode;
// int conv_dim_size = x_shape.NumAxes() - 2;
// for (int i=0; i<3; i++) {
// params.padding[i] = cdesc.CD_padding[i];
// params.stride[i] = cdesc.CD_stride[i];
// params.dilation[i] = cdesc.CD_dilation[i];
// }
mode = cdesc.CD_mode;
params.data_type = cdesc.CD_data_type;
// mode = cdesc.CD_mode;
// params.data_type = cdesc.CD_data_type;
// OF_CUDNN_CHECK(cudnnGetConvolutionNdDescriptor(cdesc.Get(), CudnnConvParams::kConvMaxDims,
// &conv_dim_size, params.padding, params.stride,
// params.dilation, &mode, &params.data_type));
CHECK_EQ(params.x_data_type, params.w_data_type);
CHECK_EQ(params.x_ndim, params.w_ndim);
// CHECK_EQ(conv_dim_size + 2, params.x_ndim);
// CHECK_EQ(params.x_data_type, params.w_data_type);
// CHECK_EQ(params.x_ndim, params.w_ndim);
// // CHECK_EQ(conv_dim_size + 2, params.x_ndim);
// params.groups = cdesc.CD_groups;
// params.max_ws_size = max_workspace_size;
params.groups = cdesc.CD_groups;
// OF_CUDNN_CHECK(cudnnGetConvolutionGroupCount(cdesc.Get(), &params.groups));
params.max_ws_size = max_workspace_size;
}
CudnnConvArgs::CudnnConvArgs(const user_op::KernelComputeContext& ctx, DataType x_data_type,
......@@ -910,51 +897,56 @@ CudnnConvArgs::CudnnConvArgs(const user_op::KernelComputeContext& ctx, DataType
wdesc(w_data_type, w_shape, data_format),
cdesc(GetConvDescDataType(x_data_type, enable_pseudo_half), x_data_type, x_shape, ctx),
heuristic(heuristic_search),
deterministic(use_deterministic_algo_only) {
std::memset(&params, 0, sizeof(CudnnConvParams));
OF_CUDNN_CHECK(hipdnnGetTensorNdDescriptor(xdesc.Get(), CudnnConvParams::kTensorMaxDims,
&params.x_data_type, &params.x_ndim, params.x_dims,
params.x_strides));
OF_CUDNN_CHECK(hipdnnGetTensorNdDescriptor(ydesc.Get(), CudnnConvParams::kTensorMaxDims,
&params.y_data_type, &params.y_ndim, params.y_dims,
params.y_strides));
OF_CUDNN_CHECK(hipdnnGetFilterNdDescriptor(wdesc.Get(), CudnnConvParams::kTensorMaxDims,
&params.w_data_type, &params.w_format, &params.w_ndim,
params.w_dims));
hipdnnConvolutionMode_t mode;
int conv_dim_size = x_shape.NumAxes() - 2;
for (int i=0; i<3; i++) {
params.padding[i] = cdesc.CD_padding[i];
params.stride[i] = cdesc.CD_stride[i];
params.dilation[i] = cdesc.CD_dilation[i];
}
deterministic(use_deterministic_algo_only),
max_ws_size(max_workspace_size) {
// std::memset(&params, 0, sizeof(CudnnConvParams));
// OF_CUDNN_CHECK(hipdnnGetTensorNdDescriptor(xdesc.Get(), CudnnConvParams::kTensorMaxDims,
// &params.x_data_type, &params.x_ndim, params.x_dims,
// params.x_strides));
// OF_CUDNN_CHECK(hipdnnGetTensorNdDescriptor(ydesc.Get(), CudnnConvParams::kTensorMaxDims,
// &params.y_data_type, &params.y_ndim, params.y_dims,
// params.y_strides));
// OF_CUDNN_CHECK(hipdnnGetFilterNdDescriptor(wdesc.Get(), CudnnConvParams::kTensorMaxDims,
// &params.w_data_type, &params.w_format, &params.w_ndim,
// params.w_dims));
// hipdnnConvolutionMode_t mode;
// int conv_dim_size = x_shape.NumAxes() - 2;
// for (int i=0; i<3; i++) {
// params.padding[i] = cdesc.CD_padding[i];
// params.stride[i] = cdesc.CD_stride[i];
// params.dilation[i] = cdesc.CD_dilation[i];
// }
mode = cdesc.CD_mode;
params.data_type = cdesc.CD_data_type;
// mode = cdesc.CD_mode;
// params.data_type = cdesc.CD_data_type;
// OF_CUDNN_CHECK(cudnnGetConvolutionNdDescriptor(cdesc.Get(), CudnnConvParams::kConvMaxDims,
// &conv_dim_size, params.padding, params.stride,
// params.dilation, &mode, &params.data_type));
CHECK_EQ(params.x_data_type, params.w_data_type);
CHECK_EQ(params.x_ndim, params.w_ndim);
// CHECK_EQ(conv_dim_size + 2, params.x_ndim);
// CHECK_EQ(params.x_data_type, params.w_data_type);
// CHECK_EQ(params.x_ndim, params.w_ndim);
// // CHECK_EQ(conv_dim_size + 2, params.x_ndim);
// params.groups = cdesc.CD_groups;
// params.max_ws_size = max_workspace_size;
params.groups = cdesc.CD_groups;
// OF_CUDNN_CHECK(cudnnGetConvolutionGroupCount(cdesc.Get(), &params.groups));
params.max_ws_size = max_workspace_size;
}
ManagedCudnnConvResource::ManagedCudnnConvResource(const CudnnConvArgs& args)
: handle_(nullptr), x_dptr_(nullptr), w_dptr_(nullptr), y_dptr_(nullptr), ws_dptr_(nullptr) {
x_byte_size_ = ByteSize4Tensor(args.params.x_dims, args.params.x_ndim, args.params.x_data_type);
w_byte_size_ = ByteSize4Tensor(args.params.w_dims, args.params.w_ndim, args.params.w_data_type);
y_byte_size_ = ByteSize4Tensor(args.params.y_dims, args.params.y_ndim, args.params.y_data_type);
ws_byte_size_ = args.params.max_ws_size;
// x_byte_size_ = ByteSize4Tensor(args.params.x_dims, args.params.x_ndim, args.params.x_data_type);
// w_byte_size_ = ByteSize4Tensor(args.params.w_dims, args.params.w_ndim, args.params.w_data_type);
// y_byte_size_ = ByteSize4Tensor(args.params.y_dims, args.params.y_ndim, args.params.y_data_type);
// ws_byte_size_ = args.params.max_ws_size;
x_byte_size_ = 0;
w_byte_size_ = 0;
y_byte_size_ = 0;
ws_byte_size_ = 0;
}
ManagedCudnnConvResource::~ManagedCudnnConvResource() {
if (handle_ != nullptr) { OF_CUDNN_CHECK(hipdnnDestroy(handle_)); }
if (handle_ != nullptr) {
Singleton<CudnnHandlePool>::Get()->Put(handle_);
handle_ = nullptr;
}
if (x_dptr_ != nullptr) { OF_CUDA_CHECK(hipFree(x_dptr_)); }
if (w_dptr_ != nullptr) { OF_CUDA_CHECK(hipFree(w_dptr_)); }
if (y_dptr_ != nullptr) { OF_CUDA_CHECK(hipFree(y_dptr_)); }
......@@ -962,7 +954,7 @@ ManagedCudnnConvResource::~ManagedCudnnConvResource() {
}
hipdnnHandle_t ManagedCudnnConvResource::cudnn_handle() {
if (handle_ == nullptr) { OF_CUDNN_CHECK(hipdnnCreate(&handle_)); }
if (handle_ == nullptr) { handle_ = Singleton<CudnnHandlePool>::Get()->Get(); }
return handle_;
}
......@@ -1005,7 +997,12 @@ bool operator==(const CudnnConvParams& a, const CudnnConvParams& b) {
}
DataType GetConvDescDataType(DataType data_type, bool pseudo_half) {
return (data_type == DataType::kFloat16 && pseudo_half) ? DataType::kFloat : data_type;
if (data_type == DataType::kFloat16 && pseudo_half) {
return DataType::kFloat;
} else if (data_type == DataType::kBFloat16) {
return DataType::kFloat;
}
return data_type;
}
hipdnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res,
......@@ -1035,38 +1032,18 @@ struct CudnnConvAlgorithmSearch<hipdnnConvolutionFwdAlgoPerf_t> {
static int GetAlgoMaxCount(CudnnConvResource* res) {
int max_algo_cnt = 1;
// OF_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithmMaxCount(res->cudnn_handle(), &max_algo_cnt));
return max_algo_cnt;
}
// static void HeuristicSearch(const CudnnConvArgs& args, CudnnConvResource* res,
// std::vector<perf_t>* perf_vec) {
// int found_algo_cnt = 0;
// perf_vec->resize(GetAlgoMaxCount(res));
// OF_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(
// res->cudnn_handle(), args.xdesc.Get(), args.wdesc.Get(), args.cdesc.Get(), args.ydesc.Get(),
// perf_vec->size(), &found_algo_cnt, perf_vec->data()));
// // vector::resize does not affect the first found_algo_cnt elements.
// perf_vec->resize(found_algo_cnt);
// }
static void ExhaustiveSearch(CudnnConvArgs& args, CudnnConvResource* res,
perf_t* perf) {
int found_algo_cnt = 0;
size_t ws = 0;
hipdnnConvolutionFwdAlgo_t algo;
hipdnnGetConvolutionForwardWorkspaceSize(res->cudnn_handle(), args.xdesc.Get(),
args.wdesc.Get(), args.cdesc.Get(),
args.ydesc.Get(), algo, &ws);
res->ws_byte_size_ = ws;
res->set_ws();
args.params.max_ws_size = ws;
OF_CUDNN_CHECK(hipdnnFindConvolutionForwardAlgorithmEx(
res->cudnn_handle(), args.xdesc.Get(), res->x_const_dptr(), args.wdesc.Get(),
res->w_const_dptr(), args.cdesc.Get(), args.ydesc.Get(), res->y_mut_dptr(),
1, &found_algo_cnt, perf, res->ws_dptr(),
args.params.max_ws_size));
args.max_ws_size));
}
};
......@@ -1076,40 +1053,18 @@ struct CudnnConvAlgorithmSearch<hipdnnConvolutionBwdDataAlgoPerf_t> {
static int GetAlgoMaxCount(CudnnConvResource* res) {
int max_algo_cnt = 1;
// OF_CUDNN_CHECK(
// cudnnGetConvolutionBackwardDataAlgorithmMaxCount(res->cudnn_handle(), &max_algo_cnt));
return max_algo_cnt;
}
// static void HeuristicSearch(const CudnnConvArgs& args, CudnnConvResource* res,
// std::vector<perf_t>* perf_vec) {
// int found_algo_cnt = 0;
// perf_vec->resize(GetAlgoMaxCount(res));
// OF_CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm_v7(
// res->cudnn_handle(), args.wdesc.Get(), args.ydesc.Get(), args.cdesc.Get(), args.xdesc.Get(),
// perf_vec->size(), &found_algo_cnt, perf_vec->data()));
// // vector::resize does not affect the first found_algo_cnt elements.
// perf_vec->resize(found_algo_cnt);
// }
static void ExhaustiveSearch(CudnnConvArgs& args, CudnnConvResource* res,
perf_t* perf) {
int found_algo_cnt = 0;
size_t ws = 0;
hipdnnConvolutionBwdDataAlgo_t algo;
hipdnnGetConvolutionBackwardDataWorkspaceSize(res->cudnn_handle(), args.wdesc.Get(),
args.ydesc.Get(), args.cdesc.Get(),
args.xdesc.Get(), algo, &ws);
res->ws_byte_size_ = ws;
res->set_ws();
args.params.max_ws_size = ws;
OF_CUDNN_CHECK(hipdnnFindConvolutionBackwardDataAlgorithmEx(
res->cudnn_handle(), args.wdesc.Get(), res->w_const_dptr(), args.ydesc.Get(),
res->y_const_dptr(), args.cdesc.Get(), args.xdesc.Get(), res->x_mut_dptr(),
1, &found_algo_cnt, perf, res->ws_dptr(),
args.params.max_ws_size));
args.max_ws_size));
}
};
......@@ -1119,40 +1074,18 @@ struct CudnnConvAlgorithmSearch<hipdnnConvolutionBwdFilterAlgoPerf_t> {
static int GetAlgoMaxCount(CudnnConvResource* res) {
int max_algo_cnt = 1;
// OF_CUDNN_CHECK(
// cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(res->cudnn_handle(), &max_algo_cnt));
return max_algo_cnt;
}
// static void HeuristicSearch(const CudnnConvArgs& args, CudnnConvResource* res,
// std::vector<perf_t>* perf_vec) {
// int found_algo_cnt = 0;
// perf_vec->resize(GetAlgoMaxCount(res));
// OF_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
// res->cudnn_handle(), args.xdesc.Get(), args.ydesc.Get(), args.cdesc.Get(), args.wdesc.Get(),
// perf_vec->size(), &found_algo_cnt, perf_vec->data()));
// // vector::resize does not affect the first found_algo_cnt elements.
// perf_vec->resize(found_algo_cnt);
// }
static void ExhaustiveSearch(CudnnConvArgs& args, CudnnConvResource* res,
perf_t* perf) {
int found_algo_cnt = 0;
size_t ws = 0;
hipdnnConvolutionBwdFilterAlgo_t algo;
hipdnnGetConvolutionBackwardFilterWorkspaceSize(res->cudnn_handle(), args.xdesc.Get(),
args.ydesc.Get(), args.cdesc.Get(),
args.wdesc.Get(), algo, &ws);
res->ws_byte_size_ = ws;
res->set_ws();
args.params.max_ws_size = ws;
OF_CUDNN_CHECK(hipdnnFindConvolutionBackwardFilterAlgorithmEx(
res->cudnn_handle(), args.xdesc.Get(), res->x_const_dptr(), args.ydesc.Get(),
res->y_const_dptr(), args.cdesc.Get(), args.wdesc.Get(), res->w_mut_dptr(),
1, &found_algo_cnt, perf, res->ws_dptr(),
args.params.max_ws_size));
args.max_ws_size));
}
};
......@@ -1200,4 +1133,6 @@ EXPLICIT_INSTANTIAT_CUDNN_CONV_ALGORITHM_INTERFACE(hipdnnConvolutionBwdFilterAlg
} // namespace oneflow
#endif // WITH_ROCM
......@@ -308,6 +308,7 @@ struct CudnnConvArgs final {
CudnnConvDesc cdesc;
bool heuristic;
bool deterministic;
size_t max_ws_size;
OF_DISALLOW_COPY_AND_MOVE(CudnnConvArgs);
CudnnConvArgs(const user_op::InferContext& ctx, DataType x_data_type, const ShapeView& x_shape,
......@@ -333,17 +334,14 @@ class CudnnConvResource {
virtual const void* x_const_dptr() const = 0;
virtual const void* y_const_dptr() const = 0;
virtual void* ws_dptr() = 0;
virtual void set_ws() = 0;
size_t ws_byte_size_;
};
class AllocatedCudnnConvResource final : public CudnnConvResource {
public:
AllocatedCudnnConvResource(hipdnnHandle_t handle, void* x_dptr, void* w_dptr, void* y_dptr,
void* ws_dptr, size_t ws_byte_size)
: handle_(handle), x_dptr_(x_dptr), w_dptr_(w_dptr), y_dptr_(y_dptr), ws_dptr_(ws_dptr), ws_byte_size_(ws_byte_size) {}
// ~AllocatedCudnnConvResource() = default;
~AllocatedCudnnConvResource(){if (ws_dptr_ != nullptr) { OF_CUDA_CHECK(hipFree(ws_dptr_)); }}
void* ws_dptr)
: handle_(handle), x_dptr_(x_dptr), w_dptr_(w_dptr), y_dptr_(y_dptr), ws_dptr_(ws_dptr) {}
~AllocatedCudnnConvResource() = default;
hipdnnHandle_t cudnn_handle() override { return handle_; }
const void* x_const_dptr() const override { return x_dptr_; }
const void* w_const_dptr() const override { return w_dptr_; }
......@@ -351,13 +349,7 @@ class AllocatedCudnnConvResource final : public CudnnConvResource {
void* x_mut_dptr() override { return x_dptr_; }
void* w_mut_dptr() override { return w_dptr_; }
void* y_mut_dptr() override { return y_dptr_; }
void* ws_dptr() override {
// return ws_dptr_;
if (ws_dptr_ == nullptr) { OF_CUDA_CHECK(hipMalloc(&ws_dptr_, ws_byte_size_)); }
return ws_dptr_;
}
void set_ws() { ws_byte_size_ = CudnnConvResource::ws_byte_size_; }
size_t ws_byte_size_;
void* ws_dptr() override { return ws_dptr_; }
private:
hipdnnHandle_t handle_;
......@@ -379,8 +371,6 @@ class ManagedCudnnConvResource final : public CudnnConvResource {
const void* w_const_dptr() const override;
const void* y_const_dptr() const override;
void* ws_dptr() override;
void set_ws(){ ws_byte_size_ = CudnnConvResource::ws_byte_size_; }
size_t ws_byte_size_;
private:
hipdnnHandle_t handle_;
......@@ -391,7 +381,7 @@ class ManagedCudnnConvResource final : public CudnnConvResource {
size_t x_byte_size_;
size_t w_byte_size_;
size_t y_byte_size_;
// size_t ws_byte_size_;
size_t ws_byte_size_;
};
bool operator==(const CudnnConvParams& a, const CudnnConvParams& b);
......
......@@ -177,6 +177,45 @@ size_t GetCudnnDataTypeByteSize(cudnnDataType_t data_type) {
return byte_size;
}
CudnnHandlePool::~CudnnHandlePool() {
for (auto& pair : handle_list_map_) {
int64_t device_id = pair.first;
auto& handle_list = pair.second;
CudaCurrentDeviceGuard guard(device_id);
while (!handle_list.empty()) {
cudnnHandle_t handle = handle_list.back();
handle_list.pop_back();
OF_CUDNN_CHECK(cudnnDestroy(handle));
}
}
handle_list_map_.clear();
}
cudnnHandle_t CudnnHandlePool::Get() {
int device_id;
OF_CUDA_CHECK(cudaGetDevice(&device_id));
{
std::unique_lock<std::mutex> lock(mutex_);
std::vector<cudnnHandle_t>& handle_list = handle_list_map_[device_id];
if (!handle_list.empty()) {
cudnnHandle_t handle = handle_list.back();
handle_list.pop_back();
return handle;
}
}
cudnnHandle_t handle;
OF_CUDNN_CHECK(cudnnCreate(&handle));
return handle;
}
void CudnnHandlePool::Put(cudnnHandle_t handle) {
int device_id;
OF_CUDA_CHECK(cudaGetDevice(&device_id));
std::unique_lock<std::mutex> lock(mutex_);
std::vector<cudnnHandle_t>& handle_list = handle_list_map_[device_id];
handle_list.push_back(handle);
}
#endif // WITH_CUDA
#ifdef WITH_ROCM
......@@ -302,10 +341,6 @@ size_t GetCudnnDataTypeByteSize(hipdnnDataType_t data_type) {
case HIPDNN_DATA_FLOAT:
case HIPDNN_DATA_INT32:
case HIPDNN_DATA_INT8x4:
// case CUDNN_DATA_UINT8x4: {
// byte_size = 4;
// break;
// }
case HIPDNN_DATA_DOUBLE: {
byte_size = 8;
break;
......@@ -315,22 +350,9 @@ size_t GetCudnnDataTypeByteSize(hipdnnDataType_t data_type) {
break;
}
case HIPDNN_DATA_INT8: {
// case CUDNN_DATA_UINT8: {
byte_size = 1;
break;
}
// #if HIPDNN_VERSION > 7200
// case CUDNN_DATA_INT8x32: {
// byte_size = 32;
// break;
// }
// #endif
// #if HIPDNN_VERSION >= 8100
// case CUDNN_DATA_BFLOAT16: {
// byte_size = 2;
// break;
// }
// #endif
default: {
UNIMPLEMENTED();
}
......@@ -338,6 +360,45 @@ size_t GetCudnnDataTypeByteSize(hipdnnDataType_t data_type) {
return byte_size;
}
CudnnHandlePool::~CudnnHandlePool() {
for (auto& pair : handle_list_map_) {
int64_t device_id = pair.first;
auto& handle_list = pair.second;
CudaCurrentDeviceGuard guard(device_id);
while (!handle_list.empty()) {
hipdnnHandle_t handle = handle_list.back();
handle_list.pop_back();
OF_CUDNN_CHECK(hipdnnDestroy(handle));
}
}
handle_list_map_.clear();
}
hipdnnHandle_t CudnnHandlePool::Get() {
int device_id;
OF_CUDA_CHECK(hipGetDevice(&device_id));
{
std::unique_lock<std::mutex> lock(mutex_);
std::vector<hipdnnHandle_t>& handle_list = handle_list_map_[device_id];
if (!handle_list.empty()) {
hipdnnHandle_t handle = handle_list.back();
handle_list.pop_back();
return handle;
}
}
hipdnnHandle_t handle;
OF_CUDNN_CHECK(hipdnnCreate(&handle));
return handle;
}
void CudnnHandlePool::Put(hipdnnHandle_t handle) {
int device_id;
OF_CUDA_CHECK(hipGetDevice(&device_id));
std::unique_lock<std::mutex> lock(mutex_);
std::vector<hipdnnHandle_t>& handle_list = handle_list_map_[device_id];
handle_list.push_back(handle);
}
#endif // WITH_ROCM
template<typename T>
......@@ -366,4 +427,34 @@ template const void* CudnnSPZeroPtr<float>();
template const void* CudnnSPZeroPtr<double>();
template const void* CudnnSPZeroPtr<float16>();
const void* CudnnSPOnePtr(const DataType dtype) {
if (dtype == kDouble) {
return CudnnSPOnePtr<double>();
} else if (dtype == kFloat) {
return CudnnSPOnePtr<float>();
} else if (dtype == kFloat16) {
return CudnnSPOnePtr<float16>();
} else if (dtype == kBFloat16) {
// NOTE(guoran): kBFloat16 use float OnePtr
return CudnnSPOnePtr<float>();
} else {
UNIMPLEMENTED();
}
}
const void* CudnnSPZeroPtr(const DataType dtype) {
if (dtype == kDouble) {
return CudnnSPZeroPtr<double>();
} else if (dtype == kFloat) {
return CudnnSPZeroPtr<float>();
} else if (dtype == kFloat16) {
return CudnnSPZeroPtr<float16>();
} else if (dtype == kBFloat16) {
// NOTE(guoran): kBFloat16 use float ZeroPtr
return CudnnSPZeroPtr<float>();
} else {
UNIMPLEMENTED();
}
}
} // namespace oneflow
......@@ -96,6 +96,22 @@ const void* CudnnSPOnePtr();
template<typename T>
const void* CudnnSPZeroPtr();
const void* CudnnSPOnePtr(const DataType dtype);
const void* CudnnSPZeroPtr(const DataType dtype);
class CudnnHandlePool {
public:
CudnnHandlePool() = default;
~CudnnHandlePool();
cudnnHandle_t Get();
void Put(cudnnHandle_t handle);
private:
std::mutex mutex_;
HashMap<int64_t, std::vector<cudnnHandle_t>> handle_list_map_;
};
} // namespace oneflow
#endif // WITH_CUDA
......@@ -177,6 +193,22 @@ const void* CudnnSPOnePtr();
template<typename T>
const void* CudnnSPZeroPtr();
const void* CudnnSPOnePtr(const DataType dtype);
const void* CudnnSPZeroPtr(const DataType dtype);
class CudnnHandlePool {
public:
CudnnHandlePool() = default;
~CudnnHandlePool();
hipdnnHandle_t Get();
void Put(hipdnnHandle_t handle);
private:
std::mutex mutex_;
HashMap<int64_t, std::vector<hipdnnHandle_t>> handle_list_map_;
};
} // namespace oneflow
#endif // WITH_ROCM
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/vm/instruction.h"
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/eager/blob_instruction_type.h"
#include "oneflow/core/vm/control_stream_type.h"
#include "oneflow/core/vm/stream.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/register/register_manager.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/vm/access_blob_arg_cb_phy_instr_operand.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/eager/eager_blob_object.h"
namespace oneflow {
namespace vm {
void AccessBlobByCallbackInstructionType::Compute(vm::Instruction* instruction) const {
const auto& phy_instr_operand = instruction->phy_instr_operand();
CHECK(static_cast<bool>(phy_instr_operand));
const auto* ptr =
dynamic_cast<const vm::AccessBlobArgCbPhyInstrOperand*>(phy_instr_operand.get());
CHECK_NOTNULL(ptr);
DeviceCtx* device_ctx = instruction->stream().device_ctx().get();
OfBlob ofblob(device_ctx->stream(), ptr->eager_blob_object()->blob());
ptr->callback()(reinterpret_cast<uint64_t>(&ofblob));
}
} // namespace vm
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_BLOB_INSTRUCTION_TYPE_H_
#define ONEFLOW_CORE_EAGER_BLOB_INSTRUCTION_TYPE_H_
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/common/stream_role.h"
#include "oneflow/core/common/singleton_ptr.h"
#include "oneflow/core/vm/ep_optional_event_record_status_querier.h"
#include "oneflow/core/vm/stream.h"
#include "oneflow/core/vm/ep_event.h"
#include "oneflow/core/vm/ep_device_context.h"
namespace oneflow {
namespace vm {
class AccessBlobByCallbackInstructionType final : public vm::InstructionType {
public:
AccessBlobByCallbackInstructionType() = default;
~AccessBlobByCallbackInstructionType() override = default;
std::string DebugName(const vm::Instruction& instruction) const override {
return "AccessBlobByCallback";
}
Maybe<void> Prepare(vm::Instruction* instruction) const override { return Maybe<void>::Ok(); }
void Compute(vm::Instruction* instruction) const override;
};
class EpRecordEventInstructionType final : public vm::InstructionType {
public:
EpRecordEventInstructionType() = default;
~EpRecordEventInstructionType() override = default;
InstructionFuseType fuse_type() const override { return kEnableInstructionFuseAsTailOnly; }
void InitInstructionStatus(Instruction* instruction) const override {
auto* status_buffer = instruction->mut_status_buffer();
auto* stream = instruction->mut_stream();
instruction->stream_type().InitInstructionStatus(*stream, status_buffer);
auto* ep_device_ctx = static_cast<EpDeviceCtx*>(stream->device_ctx().get());
auto* ep_event_provider = ep_device_ctx->ep_event_provider();
const auto& ep_event = CHECK_NOTNULL(ep_event_provider)->GetReusedEpEvent();
auto* data_ptr = status_buffer->mut_buffer();
EpOptionalEventRecordStatusQuerier::MutCast(data_ptr)->reset_ep_event(ep_event);
}
Maybe<void> Prepare(vm::Instruction* instruction) const override { return Maybe<void>::Ok(); }
std::string DebugName(const vm::Instruction&) const override { return "RecordEvent"; }
void Compute(vm::Instruction* instruction) const override {}
};
} // namespace vm
struct GetRecordEventInstructionType : public StreamRoleVisitor<GetRecordEventInstructionType> {
static Maybe<const vm::InstructionType*> VisitCompute(DeviceType device_type) {
return SingletonPtr<vm::EpRecordEventInstructionType>();
}
static Maybe<const vm::InstructionType*> VisitHost2Device(DeviceType device_type) {
return SingletonPtr<vm::EpRecordEventInstructionType>();
}
static Maybe<const vm::InstructionType*> VisitDevice2Host(DeviceType device_type) {
return SingletonPtr<vm::EpRecordEventInstructionType>();
}
static Maybe<const vm::InstructionType*> VisitSyncedLaunchedCommNet(DeviceType device_type) {
return SingletonPtr<vm::EpRecordEventInstructionType>();
}
static Maybe<const vm::InstructionType*> VisitAsyncedLaunchedCommNet(DeviceType device_type) {
return SingletonPtr<vm::EpRecordEventInstructionType>();
}
static Maybe<const vm::InstructionType*> VisitBarrier(DeviceType device_type) {
UNIMPLEMENTED_THEN_RETURN();
}
static Maybe<const vm::InstructionType*> VisitCriticalSection(DeviceType device_type) {
UNIMPLEMENTED_THEN_RETURN();
}
static Maybe<const vm::InstructionType*> VisitLazyJobLauncher(DeviceType device_type) {
UNIMPLEMENTED_THEN_RETURN();
}
static Maybe<const vm::InstructionType*> VisitPinnedCompute(DeviceType device_type) {
return VisitCompute(device_type);
}
};
} // namespace oneflow
#endif // ONEFLOW_CORE_EAGER_BLOB_INSTRUCTION_TYPE_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment