Unverified Commit 56add6d5 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[feat] support eval in mevo (#884)

- During eval, we will fallback to just output projection without fusing
- added unit test to ensure the shape is correct
parent e6acdcc3
......@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
from typing import Any, Tuple
from typing import Any, Optional, Tuple
import torch
from torch import nn
......@@ -430,7 +430,14 @@ class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
# nlprob, then sum over all tokens.
return -prob.sum()
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # type: ignore
def eval_forward(self, input: torch.Tensor) -> torch.Tensor:
"""Eval time forward that doesn't fuse the softmax and NLL Loss kernels."""
return torch.matmul(input, self.proj_weight.T)
def forward(self, input: torch.Tensor, target: Optional[torch.Tensor]) -> torch.Tensor: # type: ignore
if not self.training and target is None:
return self.eval_forward(input)
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
cur_mem = round(torch.cuda.memory_allocated() / 1024 / 1024)
mem = round(torch.cuda.max_memory_allocated() / 1024 / 1024)
......
......@@ -27,6 +27,17 @@ _dense_out = {} # type: ignore
_dense_grad = {} # type: ignore
@skip_if_no_cuda
def test_mevo_eval():
"""Test eval mode without target tensor"""
weight = torch.nn.Linear(3, 4).cuda().weight
input = torch.rand(1, 5, 3).cuda()
k = MEVO(weight)
k.eval()
out = k(input, None)
assert out.shape == (1, 5, 4)
@skip_if_no_cuda
def test_mevo():
"""Test the MEVO kernel by itself."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment