Unverified commit 4e38df05, authored by Antoine Adam, committed by GitHub

remove numpy dependency

According to `setup.py`, the only dependencies are torch and einops, yet `bert_padding.py` imports `numpy` solely to multiply the elements of a `torch.Size` object. This change replaces that call with `torch.Size.numel()`, allowing FlashAttention to be used without numpy.
parent 88dc2040
 # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
-import numpy as np
 import torch
 import torch.nn.functional as F
@@ -15,7 +13,7 @@ class IndexFirstAxis(torch.autograd.Function):
         ctx.save_for_backward(indices)
         assert input.ndim >= 2
         ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
-        second_dim = np.prod(other_shape)
+        second_dim = other_shape.numel()
         # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
         # return input[indices]
         return torch.gather(rearrange(input, 'b ... -> b (...)'), 0,
@@ -71,7 +69,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
         ctx.save_for_backward(indices)
         assert input.ndim >= 2
         ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
-        second_dim = np.prod(other_shape)
+        second_dim = other_shape.numel()
         # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
         output = input[indices]
         # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
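A quick check of the equivalence this change relies on (a minimal sketch; the tensor shape is illustrative): `torch.Size` subclasses `tuple`, and its `numel()` method returns the product of the dimensions, which is exactly the value `np.prod(other_shape)` produced.

import torch

x = torch.randn(4, 3, 8)
other_shape = x.shape[1:]            # torch.Size([3, 8])
# numel() multiplies the entries of the Size object, so numpy is not needed
assert other_shape.numel() == 3 * 8  # same value np.prod(other_shape) returned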