Unverified Commit cc92a4b4 authored by Jithun Nair's avatar Jithun Nair Committed by GitHub
Browse files

Merge pull request #55 from ROCmSoftwarePlatform/IFU-master-2021-10-15

IFU-2021-10-15 (+ remove redundant defines + C10_CUDA_CHECK)
parents 1e0f9bc6 fec3141c
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
/*
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
* 2) Implicit time (diagonal masking)
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
int micro_batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if ((element_index + element) < batch_element_count) {
elements[i][it+element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
if (it < warp_iteration_limit) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < local_seq) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < local_seq) {
out[element] = elements[i][it + element] / sum[i];
} else {
out[element] = 0;
}
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
} else if (element_index < element_count) {
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
} else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
output_t *gradInput,
input_t *grad,
const input_t *output,
acc_t scale,
int micro_batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
}
}
}
}
} // end of anonymous namespace
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
output_t *dst,
const input_t *src,
const input_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor)
{
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({attn_batches, seq_len, seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_forward",
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = output_grads.size(0);
const int seq_len = output_grads.size(1);
TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_backward",
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
//backward pass is completely in-place
return output_grads;
}
}
}
}
...@@ -4,15 +4,7 @@ ...@@ -4,15 +4,7 @@
#include <stdio.h> #include <stdio.h>
int SizeTToInt(size_t data) size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features);
{
if (data > std::numeric_limits<int>::max()) {
throw std::runtime_error("Invalid cast.");
}
return static_cast<int>(data);
}
size_t get_mlp_reserved_space(int batch_size, int num_layers, const int* output_features);
template <typename T> template <typename T>
size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features); size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features);
...@@ -29,7 +21,8 @@ int mlp_fp( ...@@ -29,7 +21,8 @@ int mlp_fp(
T* Y, T* Y,
T* reserved_space, T* reserved_space,
int use_bias, int use_bias,
int activation); int activation,
void* lt_workspace);
template <typename T> template <typename T>
int mlp_bp( int mlp_bp(
...@@ -68,9 +61,10 @@ std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at ...@@ -68,9 +61,10 @@ std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at
auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor // create output/workspace tensor
// TODO(deyuf): just get buffer?
auto out = at::empty({batch_size, output_features.back()}, inputs[0].type()); auto out = at::empty({batch_size, output_features.back()}, inputs[0].type());
auto reserved_space = at::empty({SizeTToInt(reserved_size)}, inputs[0].type()); auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, inputs[0].type());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_forward", [&] { AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_forward", [&] {
std::vector<scalar_t*> w_ptr; std::vector<scalar_t*> w_ptr;
...@@ -92,7 +86,8 @@ std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at ...@@ -92,7 +86,8 @@ std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at
out.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(),
reserved_space.data_ptr<scalar_t>(), reserved_space.data_ptr<scalar_t>(),
use_bias, use_bias,
activation); activation,
(void*) (lt_workspace.data_ptr<scalar_t>()));
}); });
return {out, reserved_space}; return {out, reserved_space};
...@@ -114,7 +109,6 @@ std::vector<at::Tensor> mlp_backward( ...@@ -114,7 +109,6 @@ std::vector<at::Tensor> mlp_backward(
auto batch_size = inputs[0].size(0); auto batch_size = inputs[0].size(0);
auto input_features = inputs[0].size(1); auto input_features = inputs[0].size(1);
// TODO: not creating empty tensor for it?
bool requires_grad = inputs[0].requires_grad(); bool requires_grad = inputs[0].requires_grad();
std::vector<int> output_features; std::vector<int> output_features;
...@@ -122,7 +116,6 @@ std::vector<at::Tensor> mlp_backward( ...@@ -122,7 +116,6 @@ std::vector<at::Tensor> mlp_backward(
output_features.push_back(inputs[i + 1].size(0)); output_features.push_back(inputs[i + 1].size(0));
} }
// create outputs, length of inputs // create outputs, length of inputs
// TODO: not create bias if not needed
std::vector<at::Tensor> outputs; std::vector<at::Tensor> outputs;
for (int i = 0; i < inputs.size(); i++) { for (int i = 0; i < inputs.size(); i++) {
outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type())); // clone for testing now outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type())); // clone for testing now
...@@ -142,7 +135,7 @@ std::vector<at::Tensor> mlp_backward( ...@@ -142,7 +135,7 @@ std::vector<at::Tensor> mlp_backward(
get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data()); get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data());
// auto work_space = at::empty({work_size*4}, at::kByte); // auto work_space = at::empty({work_size*4}, at::kByte);
auto work_space = at::empty({SizeTToInt(work_size / sizeof(scalar_t))}, inputs[0].type()); auto work_space = at::empty({work_size / sizeof(scalar_t)}, inputs[0].type());
auto result = mlp_bp<scalar_t>( auto result = mlp_bp<scalar_t>(
inputs[0].data_ptr<scalar_t>(), inputs[0].data_ptr<scalar_t>(),
...@@ -170,3 +163,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -170,3 +163,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &mlp_forward, "MLP forward"); m.def("forward", &mlp_forward, "MLP forward");
m.def("backward", &mlp_backward, "MLP backward"); m.def("backward", &mlp_backward, "MLP backward");
} }
...@@ -10,6 +10,10 @@ ...@@ -10,6 +10,10 @@
#include <cublas_v2.h> #include <cublas_v2.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt
#include <cublasLt.h>
#endif
// constants for fused bias+relu kernel // constants for fused bias+relu kernel
#define BIAS_RELU_FW_NTHREADS 128 // forward number of thread per block #define BIAS_RELU_FW_NTHREADS 128 // forward number of thread per block
#define BIAS_RELU_BW_NTHREADS_X 32 // backward number of thread in feature dim #define BIAS_RELU_BW_NTHREADS_X 32 // backward number of thread in feature dim
...@@ -249,6 +253,268 @@ cublasStatus_t mlp_gemm( ...@@ -249,6 +253,268 @@ cublasStatus_t mlp_gemm(
CUBLAS_GEMM_DEFAULT_TENSOR_OP); CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif #endif
} }
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
int mlp_gemm_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
float *alpha, /* host pointer */
const at::Half* A,
int lda,
const at::Half* B,
int ldb,
float *beta, /* host pointer */
at::Half* C,
int ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
bool use_relu,
const void* bias) {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDescOpaque_t operationDesc = {};
cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
cublasLtMatmulPreferenceOpaque_t preference = {};
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
// Create operation descriptor; see cublasLtMatmulDescAttributes_t
// for details about defaults; here we just set the transforms for
// A and B.
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (use_bias) {
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
if (use_relu) {
epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
} else {
epilogue = CUBLASLT_EPILOGUE_BIAS;
}
} else {
if (use_relu) {
epilogue = CUBLASLT_EPILOGUE_RELU;
}
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
// Create matrix descriptors. Not setting any extra attributes.
status = cublasLtMatrixLayoutInit(
&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(
&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// Create preference handle; In general, extra attributes can be
// used here to disable tensor ops or to make sure algo selected
// will work with badly aligned A, B, C. However, for simplicity
// here we assume A,B,C are always well aligned (e.g., directly
// come from cudaMalloc)
status = cublasLtMatmulPreferenceInit(&preference);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// We just need the best available heuristic to try and run matmul.
// There is no guarantee that this will work. For example, if A is
// badly aligned, you can request more (e.g. 32) algos and try to
// run them one by one until something works.
status = cublasLtMatmulAlgoGetHeuristic(
ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (returnedResults == 0) {
status = CUBLAS_STATUS_NOT_SUPPORTED;
goto CLEANUP;
}
status = cublasLtMatmul(ltHandle,
&operationDesc,
alpha,
A,
&Adesc,
B,
&Bdesc,
beta,
C,
&Cdesc,
C,
&Cdesc,
&heuristicResult.algo,
workspace,
workspaceSize,
stream);
CLEANUP:
// Descriptors are no longer needed as all GPU work was already
// enqueued.
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
int mlp_gemm_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
float *alpha, /* host pointer */
const double* A,
int lda,
const double* B,
int ldb,
float *beta, /* host pointer */
double* C,
int ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
bool use_relu,
const void* bias) {
return 1;
}
int mlp_gemm_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
float *alpha, /* host pointer */
const float *A,
int lda,
const float *B,
int ldb,
float *beta, /* host pointer */
float *C,
int ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
bool use_relu,
const void* bias) {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDescOpaque_t operationDesc = {};
cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
cublasLtMatmulPreferenceOpaque_t preference = {};
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
// Create operation descriptor; see cublasLtMatmulDescAttributes_t
// for details about defaults; here we just set the transforms for
// A and B.
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (use_bias) {
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
if (use_relu) {
epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
} else {
epilogue = CUBLASLT_EPILOGUE_BIAS;
}
} else {
if (use_relu) {
epilogue = CUBLASLT_EPILOGUE_RELU;
}
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
// Create matrix descriptors. Not setting any extra attributes.
status = cublasLtMatrixLayoutInit(
&Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(
&Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// Create preference handle; In general, extra attributes can be
// used here to disable tensor ops or to make sure algo selected
// will work with badly aligned A, B, C. However, for simplicity
// here we assume A,B,C are always well aligned (e.g., directly
// come from cudaMalloc)
status = cublasLtMatmulPreferenceInit(&preference);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// We just need the best available heuristic to try and run matmul.
// There is no guarantee that this will work. For example, if A is
// badly aligned, you can request more (e.g. 32) algos and try to
// run them one by one until something works.
status = cublasLtMatmulAlgoGetHeuristic(
ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (returnedResults == 0) {
status = CUBLAS_STATUS_NOT_SUPPORTED;
goto CLEANUP;
}
status = cublasLtMatmul(ltHandle,
&operationDesc,
alpha,
A,
&Adesc,
B,
&Bdesc,
beta,
C,
&Cdesc,
C,
&Cdesc,
&heuristicResult.algo,
workspace,
workspaceSize,
stream);
CLEANUP:
// Descriptors are no longer needed as all GPU work was already
// enqueued.
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
#endif
// Bias ADD. Assume input X is [features x batch size], column major. // Bias ADD. Assume input X is [features x batch size], column major.
// Bias is one 'features' long vector, with implicit broadcast. // Bias is one 'features' long vector, with implicit broadcast.
...@@ -538,7 +804,7 @@ void get_biasAddRelu_bprop_grid_size( ...@@ -538,7 +804,7 @@ void get_biasAddRelu_bprop_grid_size(
// Get number of SMs for efficient reduction. // Get number of SMs for efficient reduction.
int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
// can switch to occupancy calculation. use 4 below now for sm_70 // can switch to occupancy calculation. use 4 below now for sm_70
int max_blocks_y = num_SMs * 4 / (*grid_x); int max_blocks_y = (num_SMs * 4+(*grid_x)-1) / (*grid_x);
// block_y should be from minimal work per thread // block_y should be from minimal work per thread
int nRedSplits = (batch_size + block_y - 1) / block_y; int nRedSplits = (batch_size + block_y - 1) / block_y;
// increase number of elem per thread redcution to not launch more than enough // increase number of elem per thread redcution to not launch more than enough
...@@ -583,7 +849,7 @@ __global__ void biasAdd_bprop( ...@@ -583,7 +849,7 @@ __global__ void biasAdd_bprop(
int nidx = 0; int nidx = 0;
// Handle non-multiple of UNROLL_FACTOR residue // Handle non-multiple of UNROLL_FACTOR residue
for (; nidx < nSpan % UNROLL_FACTOR; nidx++) { for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
int row, col, flat_idx; int64_t row, col, flat_idx;
row = f; row = f;
col = nStart + nidx; col = nStart + nidx;
flat_idx = col * features + row; flat_idx = col * features + row;
...@@ -592,7 +858,7 @@ __global__ void biasAdd_bprop( ...@@ -592,7 +858,7 @@ __global__ void biasAdd_bprop(
// Handle meat of work // Handle meat of work
for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) { for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
int row, col, flat_idx; int64_t row, col, flat_idx;
row = f; row = f;
col = nStart + nidx; col = nStart + nidx;
flat_idx = col * features + row; flat_idx = col * features + row;
...@@ -865,7 +1131,6 @@ __global__ void biasAddRelu_bprop_aligned( ...@@ -865,7 +1131,6 @@ __global__ void biasAddRelu_bprop_aligned(
} }
// block result is in db_local now for all threadIdx.y == 0 // block result is in db_local now for all threadIdx.y == 0
// TODO: maybe not useful early exit here
if(gridDim.y == 1) { if(gridDim.y == 1) {
#pragma unroll #pragma unroll
for(int ii=0;ii<ILP;ii++){ for(int ii=0;ii<ILP;ii++){
...@@ -932,7 +1197,7 @@ void get_y_offsets( ...@@ -932,7 +1197,7 @@ void get_y_offsets(
} }
// Returns the reserved space (in elements) needed for the MLP // Returns the reserved space (in elements) needed for the MLP
size_t get_mlp_reserved_space(int batch_size, int num_layers, const int* output_features) { size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features) {
size_t res_space = 0; size_t res_space = 0;
// Need to store output of every intermediate MLP - size equal to output_features[i] * batch_size // Need to store output of every intermediate MLP - size equal to output_features[i] * batch_size
// for all 'i' in [0, num_layers-1) // for all 'i' in [0, num_layers-1)
...@@ -943,7 +1208,7 @@ size_t get_mlp_reserved_space(int batch_size, int num_layers, const int* output_ ...@@ -943,7 +1208,7 @@ size_t get_mlp_reserved_space(int batch_size, int num_layers, const int* output_
} }
// Returns the size of all fprop activations combined // Returns the size of all fprop activations combined
size_t get_all_activations_size(int batch_size, int num_layers, const int* output_features) { size_t get_all_activations_size(int64_t batch_size, int num_layers, const int* output_features) {
size_t acts_size = 0; size_t acts_size = 0;
for (int l = 0; l < num_layers; l++) { for (int l = 0; l < num_layers; l++) {
acts_size += output_features[l] * batch_size; acts_size += output_features[l] * batch_size;
...@@ -1064,7 +1329,8 @@ int mlp_fp( ...@@ -1064,7 +1329,8 @@ int mlp_fp(
T* Y, T* Y,
T* reserved_space, T* reserved_space,
int use_bias, int use_bias,
int activation) { int activation,
void* lt_workspace) {
T *weight, *input, *output, *bias; T *weight, *input, *output, *bias;
T *reserved_space_x, *reserved_space_y; T *reserved_space_x, *reserved_space_y;
reserved_space_x = NULL; reserved_space_x = NULL;
...@@ -1089,9 +1355,40 @@ int mlp_fp( ...@@ -1089,9 +1355,40 @@ int mlp_fp(
float one = 1.f; float one = 1.f;
float zero = 0.f; float zero = 0.f;
cublasStatus_t cublas_status; // try with cublaslt first for supported case with valid handle
// Call GEMM: fprop is Y = W'X int cublaslt_status = 1;
cublas_status = mlp_gemm( #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
if(activation < 1){
cublaslt_status = mlp_gemm_lt(
//ltHandle,
(cublasLtHandle_t)handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
ofeat,
batch_size,
ifeat,
&one,
weight,
ifeat,
input,
ifeat,
&zero,
output,
ofeat,
lt_workspace,
1 << 22,
stream,
use_bias == 1,
activation == 1,
bias);
}
#endif
// if cublaslt failed or not executed, fallback to cublas
if (cublaslt_status != 0) {
cublasStatus_t cublas_status;
// Call GEMM: fprop is Y = W'X
cublas_status = mlp_gemm(
handle, handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
...@@ -1107,39 +1404,39 @@ int mlp_fp( ...@@ -1107,39 +1404,39 @@ int mlp_fp(
output, output,
ofeat); ofeat);
if (cublas_status != CUBLAS_STATUS_SUCCESS) { if (cublas_status != CUBLAS_STATUS_SUCCESS) {
printf("GEMM fprop failed with %d\n", cublas_status); printf("GEMM fprop failed with %d\n", cublas_status);
return 1; return 1;
}
const uint &input_size = ofeat;
int num_blocks = 0;
int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
// Call biasReLU
if(use_bias == 1) {
if (activation == 0) { // no activation
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
} else if (activation == 1) { // relu
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAddRelu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
biasAddRelu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
} else if (activation == 2) { // sigmoid
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
} }
} else {
// don't need to do anything in case of no activation and no bias const uint &input_size = ofeat;
if (activation == 1) { // relu int num_blocks = 0;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0); int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
Relu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size); // Call biasReLU
} else if (activation == 2) { // sigmoid if(use_bias == 1) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0); if (activation == 0) { // no activation
Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size); cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
} else if (activation == 1) { // relu
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAddRelu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
biasAddRelu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
} else if (activation == 2) { // sigmoid
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
}
} else {
// don't need to do anything in case of no activation and no bias
if (activation == 1) { // relu
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Relu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
} else if (activation == 2) { // sigmoid
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
}
} }
} }
// Set current output as next layer input // Set current output as next layer input
reserved_space_x = reserved_space_y; reserved_space_x = reserved_space_y;
// Set next layer output // Set next layer output
...@@ -1366,7 +1663,8 @@ template int mlp_fp<float>( ...@@ -1366,7 +1663,8 @@ template int mlp_fp<float>(
float* Y, float* Y,
float* reserved_space, float* reserved_space,
int use_bias, int use_bias,
int activation); int activation,
void* lt_workspace);
template int mlp_bp<float>( template int mlp_bp<float>(
float* X, float* X,
...@@ -1397,7 +1695,8 @@ template int mlp_fp<at::Half>( ...@@ -1397,7 +1695,8 @@ template int mlp_fp<at::Half>(
at::Half* Y, at::Half* Y,
at::Half* reserved_space, at::Half* reserved_space,
int use_bias, int use_bias,
int activation); int activation,
void* lt_workspace);
template int mlp_bp<at::Half>( template int mlp_bp<at::Half>(
at::Half* X, at::Half* X,
...@@ -1428,7 +1727,8 @@ template int mlp_fp<double>( ...@@ -1428,7 +1727,8 @@ template int mlp_fp<double>(
double* Y, double* Y,
double* reserved_space, double* reserved_space,
int use_bias, int use_bias,
int activation); int activation,
void* lt_workspace);
template int mlp_bp<double>( template int mlp_bp<double>(
double* X, double* X,
...@@ -1460,3 +1760,4 @@ template size_t get_mlp_bp_workspace_in_bytes<double>( ...@@ -1460,3 +1760,4 @@ template size_t get_mlp_bp_workspace_in_bytes<double>(
int batch_size, int batch_size,
int num_layers, int num_layers,
const int* output_features); const int* output_features);
...@@ -435,6 +435,11 @@ void multi_tensor_norm_out_cuda( ...@@ -435,6 +435,11 @@ void multi_tensor_norm_out_cuda(
// I could get rid of these by hacking the functor + multi tensor harness with persistence // I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now // logic, but keeping it simple for now
auto ret = at::empty({1}, output.options()); auto ret = at::empty({1}, output.options());
// Adding the following device guard since it happens sometimes that the
// tensors are on one device and the cuda stream is on another device which
// results in ILLEGAL MEM ACCESS error.
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
cleanup_v2<<<ntensors, 512, 0, stream>>>( cleanup_v2<<<ntensors, 512, 0, stream>>>(
output.DATA_PTR<float>(), output.DATA_PTR<float>(),
......
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
template<typename in_t, typename out_t>
struct L2NormScaleFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<2>& tl,
float* output,
float* output_per_tensor,
float scale,
bool per_tensor,
int max_chunks_per_tensor)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
in_t* in = (in_t*)tl.addresses[0][tensor_loc];
in += chunk_idx*chunk_size;
out_t* out = (out_t*)tl.addresses[1][tensor_loc];
out += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
__shared__ float s_vals[512];
float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
in_t r_in[ILP];
for(int i = 0; i < ILP; i++)
{
vals[i] = 0.f;
r_in[i] = 0;
}
//bool finite = true;
out_t r_out[ILP];
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out))
{
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(r_in, in, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
float next = static_cast<float>(r_in[ii]);
r_out[ii] = next*scale;
vals[ii] += next*next;
//finite = finite && isfinite(r_in[ii]);
}
load_store(out, r_out, i_start, 0);
}
}
else
{
for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
{
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_in[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_in[ii] = in[i];
float next = static_cast<float>(in[i]);
vals[ii] += next*next;
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
// finite = finite && isfinite(r_in[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
out[i] = r_out[ii];
}
}
}
float val = 0.f;
for(int i = 0; i < ILP; i++)
val += vals[i];
float final = reduce_block_into_lanes(s_vals, val);
if(threadIdx.x == 0)
{
if(!isfinite(final))
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
output[blockIdx.x] += final;
if(per_tensor)
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
}
}
};
// Probably better to template, but since we are not likely to support other norm
template<typename x_t>
struct MaxNormFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<1>& tl,
float* output,
float* output_per_tensor,
bool per_tensor,
int max_chunks_per_tensor)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
x_t* x = (x_t*)tl.addresses[0][tensor_loc];
x += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
__shared__ float s_vals[512];
float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
x_t r_x[ILP];
for(int i = 0; i < ILP; i++)
{
vals[i] = 0.f;
r_x[i] = 0;
}
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
{
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(r_x, x, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
float next = static_cast<float>(r_x[ii]);
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
}
}
}
else
{
for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
{
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
float next = static_cast<float>(x[i]);
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
}
}
}
}
float val = 0.f;
for(int i = 0; i < ILP; i++)
val = fmaxf(fabsf(val), fabsf(vals[i]));
float final = reduce_block_into_lanes_max_op(s_vals, val);
if(threadIdx.x == 0)
{
if(!isfinite(final))
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
if(per_tensor)
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
}
}
};
__global__ void cleanup_v3(
float* output,
float* output_per_tensor,
float* ret,
float* ret_per_tensor,
bool per_tensor,
int max_chunks_per_tensor)
{
__shared__ float vals[512];
if(blockIdx.x == 0)
{
float val = 0;
if(threadIdx.x < 320)
val = output[threadIdx.x];
float final = reduce_block_into_lanes(vals, val);
if(threadIdx.x == 0)
*ret = sqrt(final);
}
if(per_tensor)
{
float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor;
float val = 0;
for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
val += output_this_tensor[i];
float final = reduce_block_into_lanes(vals, val);
if(threadIdx.x == 0)
ret_per_tensor[blockIdx.x] = sqrt(final);
}
}
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float scale,
at::optional<bool> per_tensor_python)
{
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
if(per_tensor)
{
for(int t = 0; t < ntensors; t++)
{
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
if(max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
}
else
{
ret_per_tensor = at::empty({0}, float_options);
}
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_scale_cuda",
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_l2norm_scale_cuda",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
L2NormScaleFunctor<scalar_t_0, scalar_t_1>(),
output.DATA_PTR<float>(),
per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
scale,
per_tensor,
max_chunks_per_tensor);))
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
// This involves one more small kernel launches, but will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now
auto ret = at::empty({1}, output.options());
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
auto stream = at::cuda::getCurrentCUDAStream();
cleanup_v3<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
output.DATA_PTR<float>(),
per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
ret.DATA_PTR<float>(),
per_tensor ? ret_per_tensor.DATA_PTR<float>() : nullptr,
per_tensor,
max_chunks_per_tensor);
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
...@@ -34,6 +34,32 @@ ...@@ -34,6 +34,32 @@
} }
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \ #define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \ switch(TYPE) \
{ \ { \
...@@ -166,6 +192,160 @@ ...@@ -166,6 +192,160 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
} }
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch(TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_in = double; \
switch(TYPEOUT) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_out = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch(TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
template<typename T> template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes __device__ __forceinline__ T reduce_block_into_lanes
(T *x, (T *x,
......
...@@ -81,7 +81,7 @@ def parse(): ...@@ -81,7 +81,7 @@ def parse():
help='Only run 10 iterations for profiling.') help='Only run 10 iterations for profiling.')
parser.add_argument('--deterministic', action='store_true') parser.add_argument('--deterministic', action='store_true')
parser.add_argument("--local_rank", default=0, type=int) parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int)
parser.add_argument('--sync_bn', action='store_true', parser.add_argument('--sync_bn', action='store_true',
help='enabling apex sync BN.') help='enabling apex sync BN.')
......
import torch import torch
from torch.utils import cpp_extension from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from setuptools import setup, find_packages from setuptools import setup, find_packages
import subprocess import subprocess
...@@ -46,7 +46,7 @@ if not torch.cuda.is_available() and not IS_ROCM_PYTORCH: ...@@ -46,7 +46,7 @@ if not torch.cuda.is_available() and not IS_ROCM_PYTORCH:
'If you wish to cross-compile for a single specific architecture,\n' 'If you wish to cross-compile for a single specific architecture,\n'
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n') 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) == 11: if int(bare_metal_major) == 11:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
else: else:
...@@ -85,11 +85,8 @@ if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv: ...@@ -85,11 +85,8 @@ if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
if TORCH_MAJOR == 0: if TORCH_MAJOR == 0:
raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, " raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, "
"found torch.__version__ = {}".format(torch.__version__)) "found torch.__version__ = {}".format(torch.__version__))
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension cmdclass['build_ext'] = BuildExtension
if "--cpp_ext" in sys.argv: if "--cpp_ext" in sys.argv:
from torch.utils.cpp_extension import CppExtension
sys.argv.remove("--cpp_ext") sys.argv.remove("--cpp_ext")
ext_modules.append( ext_modules.append(
CppExtension('apex_C', CppExtension('apex_C',
...@@ -200,6 +197,7 @@ if "--cuda_ext" in sys.argv: ...@@ -200,6 +197,7 @@ if "--cuda_ext" in sys.argv:
'csrc/multi_tensor_scale_kernel.cu', 'csrc/multi_tensor_scale_kernel.cu',
'csrc/multi_tensor_axpby_kernel.cu', 'csrc/multi_tensor_axpby_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu', 'csrc/multi_tensor_l2norm_kernel.cu',
'csrc/multi_tensor_l2norm_scale_kernel.cu',
'csrc/multi_tensor_lamb_stage_1.cu', 'csrc/multi_tensor_lamb_stage_1.cu',
'csrc/multi_tensor_lamb_stage_2.cu', 'csrc/multi_tensor_lamb_stage_2.cu',
'csrc/multi_tensor_adam.cu', 'csrc/multi_tensor_adam.cu',
...@@ -238,6 +236,37 @@ if "--cuda_ext" in sys.argv: ...@@ -238,6 +236,37 @@ if "--cuda_ext" in sys.argv:
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros})) 'nvcc':['-O3'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fused_dense_cuda',
sources=['csrc/fused_dense.cpp',
'csrc/fused_dense_cuda.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
"""
ext_modules.append(
CUDAExtension(name='scaled_upper_triang_masked_softmax_cuda',
sources=['csrc/megatron/scaled_upper_triang_masked_softmax.cpp',
'csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='scaled_masked_softmax_cuda',
sources=['csrc/megatron/scaled_masked_softmax.cpp',
'csrc/megatron/scaled_masked_softmax_cuda.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda'] + version_dependent_macros}))
"""
if "--bnp" in sys.argv or "--cuda_ext" in sys.argv: if "--bnp" in sys.argv or "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
...@@ -338,18 +367,14 @@ if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')) ...@@ -338,18 +367,14 @@ if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h'))
generator_flag = ['-DOLD_GENERATOR'] generator_flag = ['-DOLD_GENERATOR']
if "--fast_layer_norm" in sys.argv: if "--fast_layer_norm" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--fast_layer_norm") sys.argv.remove("--fast_layer_norm")
from torch.utils.cpp_extension import BuildExtension if CUDA_HOME is None and not IS_ROCM_PYTORCH:
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
# Check, if CUDA11 is installed for compute capability 8.0 # Check, if CUDA11 is installed for compute capability 8.0
cc_flag = [] cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11: if int(bare_metal_major) >= 11:
cc_flag.append('-gencode') cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80') cc_flag.append('arch=compute_80,code=sm_80')
...@@ -368,7 +393,45 @@ if "--fast_layer_norm" in sys.argv: ...@@ -368,7 +393,45 @@ if "--fast_layer_norm" in sys.argv:
'-Iapex/contrib/csrc/layer_norm', '-Iapex/contrib/csrc/layer_norm',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda', '--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag})) '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/layer_norm")]))
if "--fmha" in sys.argv:
sys.argv.remove("--fmha")
if CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--fmha was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) < 11:
raise RuntimeError("--fmha only supported on SM80")
ext_modules.append(
CUDAExtension(name='fmhalib',
sources=[
'apex/contrib/csrc/fmha/fmha_api.cpp',
'apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu',
],
extra_compile_args={'cxx': ['-O3',
] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_80,code=sm_80',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc"), os.path.join(this_dir, "apex/contrib/csrc/fmha/src")]))
if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv: if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
...@@ -471,6 +534,40 @@ if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv: ...@@ -471,6 +534,40 @@ if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha})) 'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
if "--transducer" in sys.argv:
sys.argv.remove("--transducer")
if CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--transducer was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
ext_modules.append(
CUDAExtension(name='transducer_joint_cuda',
sources=['apex/contrib/csrc/transducer/transducer_joint.cpp',
'apex/contrib/csrc/transducer/transducer_joint_kernel.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc': ['-O3'] + version_dependent_macros},
include_dirs=[os.path.join(this_dir, 'csrc'), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")]))
ext_modules.append(
CUDAExtension(name='transducer_loss_cuda',
sources=['apex/contrib/csrc/transducer/transducer_loss.cpp',
'apex/contrib/csrc/transducer/transducer_loss_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
if "--fast_bottleneck" in sys.argv:
sys.argv.remove("--fast_bottleneck")
if CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--fast_bottleneck was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])
ext_modules.append(
CUDAExtension(name='fast_bottleneck',
sources=['apex/contrib/csrc/bottleneck/bottleneck.cpp'],
include_dirs=[os.path.join(this_dir, 'apex/contrib/csrc/cudnn-frontend/include')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag}))
if "--cuda_ext" in sys.argv: if "--cuda_ext" in sys.argv:
sys.argv.remove("--cuda_ext") sys.argv.remove("--cuda_ext")
...@@ -489,5 +586,6 @@ setup( ...@@ -489,5 +586,6 @@ setup(
description='PyTorch Extensions written by NVIDIA', description='PyTorch Extensions written by NVIDIA',
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass=cmdclass, cmdclass=cmdclass,
#cmdclass={'build_ext': BuildExtension} if ext_modules else {},
extras_require=extras, extras_require=extras,
) )
import torch
from torch.optim import Optimizer
import math
import apex
import unittest
from test_fused_optimizer import TestFusedOptimizer
from itertools import product
class Novograd(Optimizer):
"""
Implements Novograd algorithm.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.95, 0))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
grad_averaging: gradient averaging
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
"""
def __init__(self, params, lr=1e-3, betas=(0.95, 0), eps=1e-8,
weight_decay=0, grad_averaging=False, amsgrad=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay,
grad_averaging=grad_averaging,
amsgrad=amsgrad)
super(Novograd, self).__init__(params, defaults)
def __setstate__(self, state):
super(Novograd, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Sparse gradients are not supported.')
amsgrad = group['amsgrad']
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad:
max_exp_avg_sq = state['max_exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
norm = torch.sum(torch.pow(grad, 2))
if exp_avg_sq == 0:
exp_avg_sq.copy_(norm)
else:
exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
else:
denom = exp_avg_sq.sqrt().add_(group['eps'])
grad.div_(denom)
if group['weight_decay'] != 0:
grad.add_(p.data, alpha=group['weight_decay'])
if group['grad_averaging']:
grad.mul_(1 - beta1)
exp_avg.mul_(beta1).add_(grad)
p.data.add_(exp_avg, alpha=-group['lr'])
return loss
class TestFusedNovoGrad(TestFusedOptimizer):
def __init__(self, *args, **kwargs):
super(TestFusedNovoGrad, self).__init__(*args, **kwargs)
# The options for NovoGrad and FusedNovoGrad are very specific if they
# are expected to behave the same.
self.options = {'lr':1e-3, 'betas':(0.95, 0), 'eps':1e-8,
'weight_decay':0, 'grad_averaging':False, 'amsgrad':False}
self.tst_options = {'lr':1e-3, 'betas':(0.95, 0), 'eps':1e-8,
'weight_decay':0, 'grad_averaging':False, 'amsgrad':False,
'bias_correction':False, 'reg_inside_moment':True,
'norm_type':2, 'init_zero':False, 'set_grad_none':True}
self.ref_optim = Novograd
self.fused_optim = apex.optimizers.FusedNovoGrad
def test_float(self):
self.gen_single_type_test(param_type=torch.float)
def test_half(self):
self.gen_single_type_test(param_type=torch.float16)
@unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:1", "cuda:0")
for current_dev, tensor_dev in product(devices, devices):
with torch.cuda.device(current_dev):
torch.cuda.synchronize()
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
tensors = []
for size in sizes:
tensors.append(torch.rand(size, dtype=torch.float, device="cuda"))
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
tensors, self.options, self.tst_options
)
for _ in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
if __name__ == '__main__':
unittest.main()
...@@ -2,9 +2,11 @@ import unittest ...@@ -2,9 +2,11 @@ import unittest
import os import os
import random import random
import math
import torch import torch
import apex import apex
from itertools import product from itertools import product
from torch.optim import Optimizer
class TestFusedOptimizer(unittest.TestCase): class TestFusedOptimizer(unittest.TestCase):
def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7): def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
...@@ -16,28 +18,29 @@ class TestFusedOptimizer(unittest.TestCase): ...@@ -16,28 +18,29 @@ class TestFusedOptimizer(unittest.TestCase):
def tearDown(self): def tearDown(self):
pass pass
def gen_param_optim(self, tensors, options, apex_only=False): def gen_param_optim(self, tensors, options, tst_options=None):
# Adding this to make backward compatible with existing tests. Just in
# case "tst_options" are not provided, it gets a copy of options
# which contains the parameters for the reference optimizer
if tst_options == None:
tst_options = options
ref_param = [] ref_param = []
tst_param = [] tst_param = []
for tensor in tensors: for tensor in tensors:
if apex_only: ref_param.append(torch.nn.Parameter(tensor.clone()))
ref_param.append(torch.nn.Parameter(tensor.clone().float()))
else:
ref_param.append(torch.nn.Parameter(tensor.clone()))
tst_param.append(torch.nn.Parameter(tensor.clone())) tst_param.append(torch.nn.Parameter(tensor.clone()))
if apex_only: ref_optim = self.ref_optim(ref_param, **options)
ref_optim = self.fused_optim(ref_param, **options) tst_optim = self.fused_optim(tst_param, **tst_options)
else:
ref_optim = self.ref_optim(ref_param, **options)
tst_optim = self.fused_optim(tst_param, **options)
return (ref_param, tst_param, ref_optim, tst_optim) return (ref_param, tst_param, ref_optim, tst_optim)
def gen_grad(self, ref_param, tst_param, apex_only=False): def gen_grad(self, ref_param, tst_param):
for p_ref, p_tst in zip(ref_param, tst_param): for p_ref, p_tst in zip(ref_param, tst_param):
p_tst.grad = torch.rand_like(p_tst) p_ref.grad = torch.rand_like(p_ref)
p_ref.grad = p_tst.grad.detach().float() if apex_only else p_tst.grad p_tst.grad = p_ref.grad
def gen_mixed_grad(self, ref_param, tst_param, scale=1.0): def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
half_grads = [] half_grads = []
...@@ -46,11 +49,9 @@ class TestFusedOptimizer(unittest.TestCase): ...@@ -46,11 +49,9 @@ class TestFusedOptimizer(unittest.TestCase):
p_ref.grad = half_grads[-1].float() / scale p_ref.grad = half_grads[-1].float() / scale
return half_grads return half_grads
def get_max_diff(self, ref_param, tst_param, apex_only=False): def get_max_diff(self, ref_param, tst_param):
max_abs_diff = max_rel_diff = 0 max_abs_diff = max_rel_diff = 0
for p_ref, p_tst in zip(ref_param, tst_param): for p_ref, p_tst in zip(ref_param, tst_param):
if apex_only:
p_tst = p_tst.float()
max_abs_diff_p = (p_ref - p_tst).abs().max().item() max_abs_diff_p = (p_ref - p_tst).abs().max().item()
max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item() max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()
...@@ -59,21 +60,29 @@ class TestFusedOptimizer(unittest.TestCase): ...@@ -59,21 +60,29 @@ class TestFusedOptimizer(unittest.TestCase):
return max_abs_diff, max_rel_diff return max_abs_diff, max_rel_diff
def gen_single_type_test(self, param_type=torch.float, apex_only=False, device='cuda'): def gen_single_type_test(self, param_type=torch.float, device='cuda'):
nelem = 278011 nelem = 278011
# Some ref and test optimizers may require different set of options.
# This is a quick workaround to add that functionality while making
# minimum changes in existing code.
# If there is no "tst_options" field provided, safe to initialize
# the test optimizer with the parameters of reference optimizer.
if not hasattr(self, 'tst_options'):
self.tst_options = self.options
tensor = torch.rand(nelem, dtype=param_type, device=device) tensor = torch.rand(nelem, dtype=param_type, device=device)
ref_param, tst_param, ref_optim, tst_optim = \ ref_param, tst_param, ref_optim, tst_optim = \
self.gen_param_optim([tensor], self.options, apex_only=apex_only) self.gen_param_optim([tensor], self.options, self.tst_options)
for i in range(self.iters): for i in range(self.iters):
self.gen_grad(ref_param, tst_param, apex_only=apex_only) self.gen_grad(ref_param, tst_param)
ref_optim.step() ref_optim.step()
tst_optim.step() tst_optim.step()
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param, apex_only=apex_only) max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
self.assertLessEqual(max_abs_diff, self.max_abs_diff) self.assertLessEqual(max_abs_diff, self.max_abs_diff)
if not apex_only: self.assertLessEqual(max_rel_diff, self.max_rel_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
class TestFusedAdam(TestFusedOptimizer): class TestFusedAdam(TestFusedOptimizer):
...@@ -91,14 +100,6 @@ class TestFusedAdam(TestFusedOptimizer): ...@@ -91,14 +100,6 @@ class TestFusedAdam(TestFusedOptimizer):
def test_half(self): def test_half(self):
self.gen_single_type_test(param_type=torch.float16) self.gen_single_type_test(param_type=torch.float16)
# Compares bfloat16 computation against float32 as gold standard.
# Uses apex optimizers(controlled by apex_only flag) for both types.
# Doesn't use upstream optimizer like other tests as they seem to be
# numerically unstable for half types
def test_bfloat16(self):
self.max_abs_diff = 1e-2
self.gen_single_type_test(param_type=torch.bfloat16, apex_only=True)
@unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required") @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
def test_multi_device(self): def test_multi_device(self):
devices = ("cuda:0", "cuda:1") devices = ("cuda:0", "cuda:1")
...@@ -279,8 +280,5 @@ class TestFusedSGD(TestFusedOptimizer): ...@@ -279,8 +280,5 @@ class TestFusedSGD(TestFusedOptimizer):
with torch.cuda.device(current_dev): with torch.cuda.device(current_dev):
self.gen_single_type_test(param_type=torch.float, device=tensor_dev) self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -26,4 +26,4 @@ for test_dir in test_dirs: ...@@ -26,4 +26,4 @@ for test_dir in test_dirs:
if not result.wasSuccessful(): if not result.wasSuccessful():
errcode = 1 errcode = 1
sys.exit(errcode) sys.exit(errcode)
\ No newline at end of file
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from apex.transformer.tensor_parallel.tests.commons import set_random_seed
from apex.transformer.tensor_parallel.tests.commons import IdentityLayer
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
from apex.transformer.tensor_parallel.tests import global_vars
global_vars.set_global_variables()
def torch_cross_entropy(batch_size, seq_length, vocab_size,
logits_scale, seed):
set_random_seed(seed)
identity = IdentityLayer((batch_size, seq_length, vocab_size),
scale=logits_scale).cuda()
logits = identity()
target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size)
loss = F.cross_entropy(logits.view(-1, logits.size()[-1]),
target.view(-1),
reduction='none').view_as(target).mean()
loss.backward()
return loss, identity.weight.grad
def tensor_sharded_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
set_random_seed(seed)
identity = IdentityLayer((batch_size, seq_length, vocab_size), scale=logits_scale).cuda()
logits = identity()
logits_parallel = tensor_parallel.scatter_to_tensor_model_parallel_region(logits)
target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size)
loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
loss.backward()
return loss, identity.weight.grad
def test_cross_entropy(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cross entropy with model parallel size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
batch_size = 13
seq_length = 17
vocab_size_per_partition = 11
logits_scale = 1000.0
vocab_size = vocab_size_per_partition * tensor_model_parallel_size
seed = 1234
loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed)
loss_mpu, grad_mpu = tensor_sharded_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed)
error = loss_torch.sub_(loss_mpu).abs().max()
print(' max error in loss on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = grad_torch.sub_(grad_mpu).abs().max()
print(' max error in grad on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test cross entropy')
test_cross_entropy(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import operator
import torch
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import data as data_utils
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
def test_broadcast_data(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing broadcast_data with model parallel size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
torch.manual_seed(1234 + parallel_state.get_data_parallel_rank())
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
key_size_t = {
'key1': [7, 11],
'key2': [8, 2, 1],
'key3': [13],
'key4': [5, 1, 2],
'key5': [5, 12],
}
keys = list(key_size_t.keys())
data = {}
data_t = {}
for key in key_size_t:
data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
data_t[key] = data[key].clone()
data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
data_t['keyX'] = data['keyX'].clone()
if parallel_state.get_tensor_model_parallel_rank() != 0:
data = None
data_utils._check_data_types(keys, data_t, torch.int64)
key_size, key_numel, \
total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
for key in keys:
assert key_size[key] == key_size_t[key]
total_numel_t = 0
for key in keys:
target_size = functools.reduce(operator.mul, key_size_t[key], 1)
assert key_numel[key] == target_size
total_numel_t += target_size
assert total_numel == total_numel_t
data_b = data_utils.broadcast_data(keys, data, torch.int64)
for key in keys:
tensor = data_t[key].cuda()
assert data_b[key].sub(tensor).abs().max() == 0
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test test broadcast data')
test_broadcast_data(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
def test_initialize_model_parallel(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing initialize_model_parallel with size {} ...'.format(
tensor_model_parallel_size))
tensor_model_parallel_size_ = min(
tensor_model_parallel_size,
torch.distributed.get_world_size(),
)
assert not parallel_state.model_parallel_is_initialized()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_)
assert parallel_state.model_parallel_is_initialized()
# Checks.
def check(group, world_size, rank):
assert world_size == torch.distributed.get_world_size(group=group)
assert rank == torch.distributed.get_rank(group=group)
# Model parallel.
world_size = tensor_model_parallel_size_
rank = torch.distributed.get_rank() % tensor_model_parallel_size_
assert world_size == parallel_state.get_tensor_model_parallel_world_size()
assert rank == parallel_state.get_tensor_model_parallel_rank()
check(parallel_state.get_tensor_model_parallel_group(), world_size, rank)
# Data parallel.
world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
rank = torch.distributed.get_rank() // tensor_model_parallel_size
assert world_size == parallel_state.get_data_parallel_world_size()
assert rank == parallel_state.get_data_parallel_rank()
check(parallel_state.get_data_parallel_group(), world_size, rank)
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
if torch.distributed.get_rank() == 0:
print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
tensor_model_parallel_size_))
tensor_model_parallel_size = min(
tensor_model_parallel_size_,
torch.distributed.get_world_size(),
)
assert not parallel_state.model_parallel_is_initialized()
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
assert parallel_state.model_parallel_is_initialized()
# Checks
src_rank = torch.distributed.get_rank() - parallel_state.get_tensor_model_parallel_rank()
assert parallel_state.get_tensor_model_parallel_src_rank() == src_rank
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test initialize model parallel')
test_initialize_model_parallel(tensor_model_parallel_size)
print_separator('test model parallel source rank')
test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.init as init
from torch.nn.parameter import Parameter
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import layers
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.tensor_parallel.tests.commons import set_random_seed
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
class IdentityLayer3D(torch.nn.Module):
def __init__(self, m, n, k):
super(IdentityLayer3D, self).__init__()
self.weight = Parameter(torch.Tensor(m, n, k))
torch.nn.init.xavier_normal_(self.weight)
def forward(self):
return self.weight
def test_parallel_embedding(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing parallel embedding with model parallel size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
batch_size = 17
seq_length = 23
vocab_size = 48
hidden_size = 16
seed = 1236
set_random_seed(123)
input_data = torch.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()
set_random_seed(seed)
embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()
output = embedding_original(input_data)
loss_original = torch.mul(output, loss_weight).sum()
loss_original.backward()
set_random_seed(seed)
embedding_parallel = layers.ParallelEmbedding(
vocab_size, hidden_size, init_method=init.normal_).cuda()
output = embedding_parallel(input_data)
loss_parallel = torch.mul(output, loss_weight).sum()
loss_parallel.backward()
set_random_seed(seed)
embedding_vocab_parallel = layers.VocabParallelEmbedding(
vocab_size, hidden_size, init_method=init.normal_).cuda()
output = embedding_vocab_parallel(input_data)
loss_vocab_parallel = torch.mul(output, loss_weight).sum()
loss_vocab_parallel.backward()
torch.distributed.barrier()
error = loss_parallel.sub(loss_original).abs()
print(' error in loss (parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
torch.distributed.barrier()
error = loss_vocab_parallel.sub(loss_original).abs()
print(' error in loss (vocab parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad,
hidden_size // tensor_model_parallel_size,
1)[parallel_state.get_tensor_model_parallel_rank()]
error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
print(' error in grad (parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad,
vocab_size // tensor_model_parallel_size,
0)[parallel_state.get_tensor_model_parallel_rank()]
error = embedding_vocab_parallel.weight.grad.sub(
weight_grad_orig).abs().max()
print(' error in grad (vocab parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
def test_initialize_affine_weight(tensor_model_parallel_size, device):
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing initialize_affine_weight with model parallel '
'size: {}'.format(tensor_model_parallel_size))
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
# ---------------
# Column parallel
# ---------------
weight = torch.empty(output_size_coeff, input_size)
set_random_seed(seed)
if device == 'cpu':
layers._initialize_affine_weight_cpu(weight, output_size, input_size,
output_size_coeff, 0,
torch.nn.init.normal_,
params_dtype=global_vars.get_args().params_dtype,
)
else:
layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, 0)
# Target.
set_random_seed(seed)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)
rank = parallel_state.get_tensor_model_parallel_rank()
my_weight = torch.split(master_weight, output_size_coeff,
dim=0)[rank].contiguous().clone()
# Compare.
error = weight.sub(my_weight).abs().max()
torch.distributed.barrier()
print(' column parallel max error (should be zero) on global rank '
'{}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# ------------
# Row parallel
# ------------
weight = torch.empty(output_size, input_size_coeff)
set_random_seed(seed)
if device == 'cpu':
layers._initialize_affine_weight_cpu(
weight, output_size, input_size, input_size_coeff, 1, torch.nn.init.normal_,
params_dtype=global_vars.get_args().params_dtype)
else:
layers._initialize_affine_weight_gpu(weight, torch.nn.init.normal_, 1)
# Target.
set_random_seed(seed)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)
rank = parallel_state.get_tensor_model_parallel_rank()
my_weight = torch.split(master_weight, input_size_coeff,
dim=1)[rank].contiguous().clone()
# Compare.
error = weight.sub(my_weight).abs().max()
torch.distributed.barrier()
print(' row parallel max error (should be zero) on global rank '
'{}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
class IdentityLayer2D(torch.nn.Module):
def __init__(self, m, n):
super(IdentityLayer2D, self).__init__()
self.weight = Parameter(torch.Tensor(m, n))
torch.nn.init.xavier_normal_(self.weight)
def forward(self):
return self.weight
def test_column_parallel_linear(tensor_model_parallel_size):
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing ColumnParallelLinear with model parallel '
'size: {}'.format(tensor_model_parallel_size))
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
# Network
identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
linear_layer = layers.ColumnParallelLinear(
input_size, output_size, keep_master_weight_for_test=True,
params_dtype=global_vars.get_args().params_dtype,
use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
).cuda()
loss_weight = torch.randn([batch_size, output_size]).cuda()
# Forward
input_ = identity_layer()
output, _ = linear_layer(input_)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
# Values.
dLdY = loss_weight
X = identity_layer.weight
A = linear_layer.master_weight.cuda()
dLdA = torch.matmul(dLdY.t(), X)
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A)
rank = parallel_state.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, output_size_coeff,
dim=0)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdA on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
my_dLdb = torch.split(dLdb, output_size_coeff,
dim=0)[rank].contiguous().clone()
error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdb on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = dLdX.sub(identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdX on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def test_column_parallel_linear_with_async_allreduce_autocast(tensor_model_parallel_size):
autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
# Network
identity_layer = IdentityLayer3D(batch_size, batch_size, input_size).cuda()
linear_layer = layers.ColumnParallelLinear(
input_size, output_size, keep_master_weight_for_test=True,
params_dtype=global_vars.get_args().params_dtype,
use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
).cuda()
assert linear_layer.async_tensor_model_parallel_allreduce or tensor_model_parallel_size == 1
# Forward
for dtype in autocast_dtypes:
loss_weight = torch.randn([batch_size, output_size]).cuda()
with torch.cuda.amp.autocast(dtype=dtype):
output, _ = linear_layer(identity_layer())
loss = torch.mul(output, loss_weight).sum()
assert output.dtype == dtype
# Backward
loss.backward()
torch.distributed.barrier()
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def test_column_parallel_linear_with_async_allreduce_custom_amp(tensor_model_parallel_size):
dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
for dtype in dtypes:
# Network
identity_layer = IdentityLayer3D(batch_size, batch_size, input_size).to(device="cuda", dtype=dtype)
linear_layer = layers.ColumnParallelLinear(
input_size, output_size, keep_master_weight_for_test=True,
params_dtype=global_vars.get_args().params_dtype,
use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
).to(device="cuda", dtype=dtype)
# Forward
loss_weight = torch.randn([batch_size, output_size]).cuda()
output, _ = linear_layer(identity_layer())
loss = torch.mul(output, loss_weight).sum()
loss.backward()
torch.distributed.barrier()
assert output.dtype == dtype
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def test_row_parallel_linear(tensor_model_parallel_size):
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing RowParallelLinear with model parallel '
'size: {}'.format(tensor_model_parallel_size))
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
# Network
identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
linear_layer = layers.RowParallelLinear(
input_size, output_size, keep_master_weight_for_test=True,
params_dtype=global_vars.get_args().params_dtype,
use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
).cuda()
loss_weight = torch.randn([batch_size, output_size]).cuda()
# Forward
input_ = identity_layer()
output, _ = linear_layer(input_)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
# Values.
dLdY = loss_weight
X = identity_layer.weight
A = linear_layer.master_weight.cuda()
dLdA = torch.matmul(dLdY.t(), X)
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A)
rank = parallel_state.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, input_size_coeff,
dim=1)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdA on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = dLdb.sub(linear_layer.bias.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdb on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = dLdX.sub(identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdX on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size,
sequence_length):
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
num_att_heads = num_att_heads_per_partition * \
torch.distributed.get_world_size()
hidden_size = hidden_size_per_att_head * num_att_heads
# Network
identity_layer = IdentityLayer3D(batch_size, sequence_length,
hidden_size).cuda()
attention_layer = parallel_state.BertParallelSelfAttention(hidden_size, num_att_heads,
dropout_prob).cuda()
loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
# Forward
input_ = identity_layer()
output = attention_layer(input_, attention_mask)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
rank = parallel_state.get_tensor_model_parallel_rank()
parallel_state.destroy_model_parallel()
return rank, hidden_size, tensor_model_parallel_size, loss, \
attention_layer, identity_layer
def test_parallel_self_attention(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing ParallelSelfAttention with model parallel '
'size: {}'.format(tensor_model_parallel_size))
num_att_heads_per_partition = 3
hidden_size_per_att_head = 7
dropout_prob = 0.0 # has to be zero
batch_size = 5
sequence_length = 13
rank_1, hideen_size_1, tensor_model_parallel_size_1, loss_1, \
attention_layer_1, identity_layer_1 = parallel_self_attention(
1, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
rank, hidden_size, tensor_model_parallel_size, loss, \
attention_layer, identity_layer = parallel_self_attention(
tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
assert hideen_size_1 == hidden_size
error = loss_1.sub(loss).abs().max()
torch.distributed.barrier()
print(' loss error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-6
my_lin_grad_list = torch.split(
attention_layer_1.query_key_value.weight.grad,
hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size]
my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
error = my_lin_grad.sub(
attention_layer.query_key_value.weight.grad).abs().max()
torch.distributed.barrier()
print(' weight gradient error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-6
error = identity_layer_1.weight.grad.sub(
identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' input gradient error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-6
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length):
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
num_att_heads = num_att_heads_per_partition * \
torch.distributed.get_world_size()
hidden_size = hidden_size_per_att_head * num_att_heads
intermediate_size = 4 * hidden_size
# Network
identity_layer = IdentityLayer3D(batch_size, sequence_length,
hidden_size).cuda()
transformer_layer = parallel_state.BertParallelTransformerLayer(
hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
torch.nn.functional.relu, 1.0e-5).cuda()
loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
# Forward
input_ = identity_layer()
output = transformer_layer(input_, attention_mask)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
rank = parallel_state.get_tensor_model_parallel_rank()
parallel_state.destroy_model_parallel()
return rank, hidden_size, tensor_model_parallel_size, loss, \
transformer_layer, identity_layer
def test_parallel_transformer_layer(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing ParallelTransformerLayer with model parallel '
'size: {}'.format(tensor_model_parallel_size))
num_att_heads_per_partition = 3
hidden_size_per_att_head = 7
batch_size = 5
sequence_length = 13
rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
transformer_layer_1, identity_layer_1 = parallel_transformer(
1, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length)
rank, hidden_size, tensor_model_parallel_size, loss, \
transformer_layer, identity_layer = parallel_transformer(
tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length)
error = loss_1.sub(loss).abs().max()
torch.distributed.barrier()
print(' loss error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-5, 'error: {}'.format(error)
error = identity_layer_1.weight.grad.sub(
identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' input gradient error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-5, 'error: {}'.format(error)
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
exceptions = []
print_separator('test initialize affine weight cpu')
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
try:
test_initialize_affine_weight(tensor_model_parallel_size, 'cpu')
except Exception as e:
exceptions.append(f"test_initialize_affine_weight-cpu with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
# Reset groups
parallel_state.destroy_model_parallel()
print_separator('test initialize affine weight gpu')
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
try:
test_initialize_affine_weight(tensor_model_parallel_size, 'gpu')
except Exception as e:
exceptions.append(f"test_initialize_affine_weight-gpu with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
# Deleted, replaced with vocab parallel embedding?
#tensor_model_parallel_size = 1
#while tensor_model_parallel_size <= world_size:
# print_separator('test parallel embedding')
# test_parallel_embedding(tensor_model_parallel_size)
# tensor_model_parallel_size *= 2
print_separator('test column-parallel linear')
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
try:
test_column_parallel_linear(tensor_model_parallel_size)
except Exception as e:
exceptions.append(f"test_column_parallel_linear with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
print_separator('test row-parallel linear')
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
try:
test_row_parallel_linear(tensor_model_parallel_size)
except Exception as e:
exceptions.append(f"test_row_parallel_linear with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
print_separator("test ColumnParallelLinearWithAsyncAllreduce - autocast")
tensor_model_parallel_size = 2
while tensor_model_parallel_size <= world_size:
try:
test_column_parallel_linear_with_async_allreduce_autocast(tensor_model_parallel_size)
except Exception as e:
exceptions.append(f"test_column_parallel_linear_with_async_allreduce_autocast with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
print_separator("test ColumnParallelLinearWithAsyncAllreduce - custom AMP")
tensor_model_parallel_size = 2
while tensor_model_parallel_size <= world_size:
try:
test_column_parallel_linear_with_async_allreduce_custom_amp(tensor_model_parallel_size)
except Exception as e:
exceptions.append(f"test_column_parallel_linear_with_async_allreduce_custom_amp with tensor model parallel size of {tensor_model_parallel_size} failed: {str(e)}")
# Reset groups
parallel_state.destroy_model_parallel()
break
else:
tensor_model_parallel_size *= 2
if exceptions:
raise RuntimeError("\n".join(exceptions))
# Deleted
#print_separator('test parallel self-attention')
#tensor_model_parallel_size = 1
#while tensor_model_parallel_size <= world_size:
# test_parallel_self_attention(tensor_model_parallel_size)
# tensor_model_parallel_size *= 2
#Deleted because PararallelTransformerLayer no longer exists
# print_separator('test parallel transformer')
# tensor_model_parallel_size = 1
# while tensor_model_parallel_size <= world_size:
# test_parallel_transformer_layer(tensor_model_parallel_size)
# tensor_model_parallel_size *= 2
import torch
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel import mappings
from apex.transformer.tensor_parallel.tests import global_vars
global_vars.set_global_variables()
def test__reduce(args, tensor_model_parallel_size):
print("Testing reduction size =", tensor_model_parallel_size)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
assert torch.equal(
mappings._reduce(torch.full((10, 10, 10, 10), (50))),
torch.full((10, 10, 10, 10), 50 * tensor_model_parallel_size),
)
parallel_state.destroy_model_parallel()
print("Passed!")
def test__split(args, tensor_model_parallel_size):
print("Testing splitting size =", tensor_model_parallel_size)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
listy = []
for i in range(tensor_model_parallel_size):
listy.append(torch.randn(10, 1))
x = torch.cat(tuple(listy), 1)
out = mappings._split(x)
assert torch.equal(out, listy[parallel_state.get_tensor_model_parallel_rank()])
parallel_state.destroy_model_parallel()
print("Passed!")
def test__gather(args, tensor_model_parallel_size):
print("Testing gathering size =", tensor_model_parallel_size)
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
assert torch.equal(
mappings._gather(torch.tensor([parallel_state.get_tensor_model_parallel_rank()])),
torch.tensor(list(range(tensor_model_parallel_size))),
)
parallel_state.destroy_model_parallel()
print("Passed!")
if __name__ == "__main__":
initialize_distributed()
world_size = torch.distributed.get_world_size()
args = global_vars.get_args()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test__reduce(args, tensor_model_parallel_size)
test__split(args, tensor_model_parallel_size)
test__gather(args, tensor_model_parallel_size)
tensor_model_parallel_size *= 2
print(">> passed the test :-)")
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.tensor_parallel.tests import global_vars
from apex.transformer.tensor_parallel.tests.commons import print_separator
from apex.transformer.tensor_parallel.tests.commons import initialize_distributed
from apex.transformer.tensor_parallel.tests.commons import TEST_SUCCESS_MESSAGE
global_vars.set_global_variables()
def test_set_cuda_rng_state(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing set_rng_state with size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
size = 123
seed = 1234
torch.cuda.manual_seed(seed)
tensor = torch.cuda.FloatTensor(size)
# Get the state
rng_state = torch.cuda.get_rng_state()
rng_state_copy = rng_state.clone()
# Do some stuff.
for _ in range(5):
torch.randn(size, out=tensor)
result_1 = tensor.clone()
assert rng_state.sub(rng_state_copy).max() == 0
assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0
# State should be different.
new_rng_state = torch.cuda.get_rng_state()
max_diff = new_rng_state.sub(rng_state).max()
print(' max diff in rng state (should be non-zero) on global rank {}: {}'.
format(torch.distributed.get_rank(), max_diff))
assert max_diff > 0
# Reset the rng state and do the same stuff.
tensor_parallel.random._set_cuda_rng_state(rng_state)
for _ in range(5):
torch.randn(size, out=tensor)
tensor_parallel.random._set_cuda_rng_state(rng_state)
for _ in range(5):
torch.randn(size, out=tensor)
result_2 = tensor.clone()
# Results should be the same
error = result_2.sub(result_1).abs().max()
print(' max error in generated tensors (should be zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Input state should have remained intact.
error = rng_state.sub(rng_state_copy).max()
print(' max error in rng state (should be zero) on global rank {}: {}'.
format(torch.distributed.get_rank(), error))
assert error == 0
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
def test_cuda_rng_tracker(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cuda rng tracker with size {} ...'.
format(tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
seed_1 = 1234
seed_2 = 4321
size = [12, 21]
tensor = torch.cuda.FloatTensor(size)
# Set to seed_1 and generate two tensors.
torch.cuda.manual_seed(seed_1)
torch.randn(size, out=tensor)
target_11 = tensor.clone()
torch.randn(size, out=tensor)
target_12 = tensor.clone()
# Set to seed_2 and generate two tensors.
torch.cuda.manual_seed(seed_2)
torch.randn(size, out=tensor)
target_21 = tensor.clone()
torch.randn(size, out=tensor)
target_22 = tensor.clone()
# Now if we interleave seed_1 and seed_2,
# we should still get the same tensors
torch.cuda.manual_seed(seed_1)
tensor_parallel.random.get_cuda_rng_tracker().add('test', seed_2)
torch.randn(size, out=tensor)
result_11 = tensor.clone()
with tensor_parallel.random.get_cuda_rng_tracker().fork('test'):
torch.randn(size, out=tensor)
result_21 = tensor.clone()
torch.randn(size, out=tensor)
result_12 = tensor.clone()
with tensor_parallel.random.get_cuda_rng_tracker().fork('test'):
torch.randn(size, out=tensor)
result_22 = tensor.clone()
diff = result_11.sub(result_21).abs().max()
diff = min(diff, result_12.sub(result_22).abs().max())
print(' max diff in generated tensors (should be non-zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
assert diff > 1.0e-6
error = max(result_11.sub(target_11).abs().max(),
result_12.sub(target_12).abs().max())
error = max(error, result_21.sub(target_21).abs().max())
error = max(error, result_22.sub(target_22).abs().max())
print(' max error in generated tensors (should be zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset the tracker
tensor_parallel.random.get_cuda_rng_tracker().reset()
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print(
'> testing model parallel cuda manual seed with size {} ...'.format(
tensor_model_parallel_size))
parallel_state.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
tensor_parallel.random.model_parallel_cuda_manual_seed(12345)
assert torch.cuda.initial_seed() == 12345
with tensor_parallel.random.get_cuda_rng_tracker().fork():
assert (
torch.cuda.initial_seed() ==
12345 + 2718 + parallel_state.get_tensor_model_parallel_rank()
)
# Reset the tracker
tensor_parallel.random.get_cuda_rng_tracker().reset()
# Reset groups
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test set rng state')
test_set_cuda_rng_state(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test cuda rng tracker')
test_cuda_rng_tracker(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test model parallel cuda manual seed')
test_model_parallel_cuda_manual_seed(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment