Commit da3f0934 authored by zhuwenwen's avatar zhuwenwen
Browse files

delete unused files

parent c4dd1fd4
// !!! This is a file automatically generated by hipify!!!
/*This code from NVIDIA Megatron:
* with minor changes. */
#include <hip/hip_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
}
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/*This code from NVIDIA Megatron:
* with minor changes. */
#pragma once
#include <assert.h>
#include <hip/hip_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if TORCH_HIP_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
/*
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
* 2) Implicit time (diagonal masking)
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
int micro_batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if ((element_index + element) < batch_element_count) {
elements[i][it+element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
if (it < warp_iteration_limit) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < local_seq) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < local_seq) {
out[element] = elements[i][it + element] / sum[i];
} else {
out[element] = 0;
}
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
} else if (element_index < element_count) {
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
} else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
output_t *gradInput,
input_t *grad,
const input_t *output,
acc_t scale,
int micro_batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
}
}
}
}
} // end of anonymous namespace
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
output_t *dst,
const input_t *src,
const input_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 11: // 2048
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 11: // 2048
hipLaunchKernelGGL(( scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>)
, dim3(blocks), dim3(threads), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
// !!! This is a file automatically generated by hipify!!!
/*This code from NVIDIA Megatron:
* with minor changes. */
#include <ATen/ATen.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#ifndef COLOSSAL_HIP
#include <cuda_profiler_api.h>
#endif
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
#include "../../hip_native/csrc/scaled_upper_triang_masked_softmax.h"
#include "../../hip_native/csrc/type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor)
{
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({attn_batches, seq_len, seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_forward",
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = output_grads.size(0);
const int seq_len = output_grads.size(1);
TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_backward",
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
//backward pass is completely in-place
return output_grads;
}
}
}
}
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <ATen/ATen.h>
#include "../../hip_native/csrc/compat.h"
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch(TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: \
{ \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T *x,
T val,
int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if (tid < i)
x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32)
{
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
#ifdef COLOSSAL_HIP
final = final + __shfl_down(final, i);
#else
final = final + __shfl_down_sync(0xffffffff, final, i);
#endif
}
if (share_result)
{
if (tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x,
T val,
int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if (tid < i)
x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads();
}
T final;
if (tid < 32)
{
if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
#ifdef COLOSSAL_HIP
final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i)));
#else
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
#endif
}
if (share_result)
{
if (tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
from .option import set_jit_fusion_options
from .bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference
from .bias_gelu import bias_gelu_impl
__all__ = [
"bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl",
"set_jit_fusion_options"
]
import torch
def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
return out
@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, False)
import torch
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff*g
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
\ No newline at end of file
import torch
JIT_OPTIONS_SET = False
def set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options.
"""
# LSG: the latest pytorch and CUDA versions may not support
# the following jit settings
global JIT_OPTIONS_SET
if JIT_OPTIONS_SET == False:
# flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
# nvfuser
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(True)
torch._C._debug_set_autodiff_subgraph_inlining(False)
else:
# legacy pytorch fuser
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
JIT_OPTIONS_SET = True
from typing import List
from .logging import DistributedLogger
import logging
__all__ = ['get_dist_logger', 'DistributedLogger']
def get_dist_logger(name='colossalai'):
"""Get logger instance based on name. The DistributedLogger will create singleton instances,
which means that only one logger instance is created per name.
:param name: name of the logger, name must be unique
:type name: str
:return: a distributed logger instance
:rtype: :class:`colossalai.logging.DistributedLogger`
"""
return DistributedLogger.get_instance(name=name)
def disable_existing_loggers(except_loggers: List[str] = ['colossalai']):
"""Set the level of existing loggers to `WARNING`.
:param except_loggers: loggers in this `list` will be ignored when disabling, defaults to ['colossalai']
:type except_loggers: list, optional
"""
for log_name in logging.Logger.manager.loggerDict.keys():
if log_name not in except_loggers:
logging.getLogger(log_name).setLevel(logging.WARNING)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import colossalai
import logging
from pathlib import Path
from typing import Union
from colossalai.context.parallel_mode import ParallelMode
_FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=_FORMAT)
class DistributedLogger:
"""This is a distributed event logger class essentially based on :class:`logging`.
:param name: The name of the logger
:type name: str
"""
__instances = dict()
@staticmethod
def get_instance(name: str):
"""Get the unique single logger instance based on name.
:param name: The name of the logger
:type name: str
:return: A DistributedLogger object
:rtype: DistributedLogger
"""
if name in DistributedLogger.__instances:
return DistributedLogger.__instances[name]
else:
logger = DistributedLogger(name=name)
return logger
def __init__(self, name):
if name in DistributedLogger.__instances:
raise Exception(
'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger')
else:
self._name = name
self._logger = logging.getLogger(name)
DistributedLogger.__instances[name] = self
@staticmethod
def _check_valid_logging_level(level: str):
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
def set_level(self, level: str):
"""Set the logging level
:param level: Can only be INFO, DEBUG, WARNING and ERROR
:type level: str
"""
self._check_valid_logging_level(level)
self._logger.setLevel(getattr(logging, level))
def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None):
"""Save the logs to file
:param path: The file to save the log
:type path: A string or pathlib.Path object
:param mode: The mode to write log into the file
:type mode: str
:param level: Can only be INFO, DEBUG, WARNING and ERROR
:type level: str
:param suffix: The suffix string of log's name
:type suffix: str
"""
assert isinstance(path, (str, Path)), \
f'expected argument path to be type str or Path, but got {type(path)}'
self._check_valid_logging_level(level)
if isinstance(path, str):
path = Path(path)
# create log directory
path.mkdir(parents=True, exist_ok=True)
# set the default file name if path is a directory
if not colossalai.core.global_context.is_initialized(ParallelMode.GLOBAL):
rank = 0
else:
rank = colossalai.core.global_context.get_global_rank()
if suffix is not None:
log_file_name = f'rank_{rank}_{suffix}.log'
else:
log_file_name = f'rank_{rank}.log'
path = path.joinpath(log_file_name)
# add file handler
file_handler = logging.FileHandler(path, mode)
file_handler.setLevel(getattr(logging, level))
formatter = logging.Formatter(_FORMAT)
file_handler.setFormatter(formatter)
self._logger.addHandler(file_handler)
def _log(self, level, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
if ranks is None:
getattr(self._logger, level)(message)
else:
local_rank = colossalai.core.global_context.get_local_rank(parallel_mode)
if local_rank in ranks:
getattr(self._logger, level)(message)
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Log an info message.
:param message: The message to be logged
:type message: str
:param parallel_mode: The parallel mode used for logging. Defaults to ParallelMode.GLOBAL
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`
:param ranks: List of parallel ranks
:type ranks: list
"""
self._log('info', message, parallel_mode, ranks)
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Log a warning message.
:param message: The message to be logged
:type message: str
:param parallel_mode: The parallel mode used for logging. Defaults to ParallelMode.GLOBAL
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`
:param ranks: List of parallel ranks
:type ranks: list
"""
self._log('warning', message, parallel_mode, ranks)
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Log a debug message.
:param message: The message to be logged
:type message: str
:param parallel_mode: The parallel mode used for logging. Defaults to ParallelMode.GLOBAL
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`
:param ranks: List of parallel ranks
:type ranks: list
"""
self._log('debug', message, parallel_mode, ranks)
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Log an error message.
:param message: The message to be logged
:type message: str
:param parallel_mode: The parallel mode used for logging. Defaults to ParallelMode.GLOBAL
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`
:param ranks: List of parallel ranks
:type ranks: list
"""
self._log('error', message, parallel_mode, ranks)
from .layer import *
from .loss import *
from .lr_scheduler import *
from .metric import *
from .model import *
from .optimizer import *
import math
import warnings
from torch import Tensor
import torch.nn as nn
def zeros_():
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.zeros_(tensor)
return initializer
def ones_():
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.ones_(tensor)
return initializer
def uniform_(a: float = 0., b: float = 1.):
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.uniform_(tensor, a, b)
return initializer
def normal_(mean: float = 0., std: float = 1.):
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.normal_(tensor, mean, std)
return initializer
def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = 2.):
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.trunc_normal_(tensor, mean, std, a, b)
return initializer
def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
# adapted from torch.nn.init
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
if 0 in tensor.shape:
warnings.warn("Initializing zero-element tensors is a no-op")
return tensor
if mode == 'fan_in':
assert fan_in is not None, 'Fan_in is not provided.'
fan = fan_in
elif mode == 'fan_out':
assert fan_out is not None, 'Fan_out is not provided.'
fan = fan_out
else:
raise ValueError(f'Invalid initialization mode \'{mode}\'')
std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan)
bound = math.sqrt(3.) * std
return nn.init.uniform_(tensor, -bound, bound)
return initializer
def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
# adapted from torch.nn.init
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
if 0 in tensor.shape:
warnings.warn("Initializing zero-element tensors is a no-op")
return tensor
if mode == 'fan_in':
assert fan_in is not None, 'Fan_in is not provided.'
fan = fan_in
elif mode == 'fan_out':
assert fan_out is not None, 'Fan_out is not provided.'
fan = fan_out
else:
raise ValueError(f'Invalid initialization mode \'{mode}\'')
std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan)
return nn.init.normal_(tensor, 0, std)
return initializer
def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1.):
# adapted from torch.nn.init
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
assert fan_in is not None, 'Fan_in is not provided.'
fan = fan_in
if fan_out is not None:
fan += fan_out
std = gain * math.sqrt(scale / float(fan))
bound = a * std
return nn.init.uniform_(tensor, -bound, bound)
return initializer
def xavier_normal_(scale: float = 2., gain: float = 1.):
# adapted from torch.nn.init
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
assert fan_in is not None, 'Fan_in is not provided.'
fan = fan_in
if fan_out is not None:
fan += fan_out
std = gain * math.sqrt(scale / float(fan))
return nn.init.normal_(tensor, 0., std)
return initializer
def lecun_uniform_():
# adapted from jax.nn.initializers
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
assert fan_in is not None, 'Fan_in is not provided.'
var = 1.0 / fan_in
bound = math.sqrt(3 * var)
return nn.init.uniform_(tensor, -bound, bound)
return initializer
def lecun_normal_():
# adapted from jax.nn.initializers
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
assert fan_in is not None, 'Fan_in is not provided.'
std = math.sqrt(1.0 / fan_in)
return nn.init.trunc_normal_(tensor, std=std / .87962566103423978)
return initializer
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment