Unverified Commit 8dd52b07 authored by Tri Dao's avatar Tri Dao Committed by GitHub
Browse files

Merge pull request #55 from ajfadam/main

remove numpy dependency
parents 88dc2040 4e38df05
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
import numpy as np
import torch
import torch.nn.functional as F
......@@ -15,7 +13,7 @@ class IndexFirstAxis(torch.autograd.Function):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = np.prod(other_shape)
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(rearrange(input, 'b ... -> b (...)'), 0,
......@@ -71,7 +69,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = np.prod(other_shape)
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
output = input[indices]
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment