"git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "1b7672937b9ba679b1ad077c1760203aefb77226"
Unverified Commit 56add6d5 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[feat] support eval in mevo (#884)

- During eval (module not in training mode and no target given), we fall back to a plain output projection without fusing the softmax and NLL loss kernels
- Added a unit test to ensure the output shape is correct
parent e6acdcc3
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import Any, Tuple from typing import Any, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -430,7 +430,14 @@ class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO ...@@ -430,7 +430,14 @@ class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
# nlprob, then sum over all tokens. # nlprob, then sum over all tokens.
return -prob.sum() return -prob.sum()
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # type: ignore def eval_forward(self, input: torch.Tensor) -> torch.Tensor:
"""Eval time forward that doesn't fuse the softmax and NLL Loss kernels."""
return torch.matmul(input, self.proj_weight.T)
def forward(self, input: torch.Tensor, target: Optional[torch.Tensor]) -> torch.Tensor: # type: ignore
if not self.training and target is None:
return self.eval_forward(input)
if DEBUG and dist.is_initialized() and dist.get_rank() == 0: if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
cur_mem = round(torch.cuda.memory_allocated() / 1024 / 1024) cur_mem = round(torch.cuda.memory_allocated() / 1024 / 1024)
mem = round(torch.cuda.max_memory_allocated() / 1024 / 1024) mem = round(torch.cuda.max_memory_allocated() / 1024 / 1024)
......
...@@ -27,6 +27,17 @@ _dense_out = {} # type: ignore ...@@ -27,6 +27,17 @@ _dense_out = {} # type: ignore
_dense_grad = {} # type: ignore _dense_grad = {} # type: ignore
@skip_if_no_cuda
def test_mevo_eval():
    """Verify that MEVO in eval mode accepts a ``None`` target.

    Builds a projection weight from a ``Linear(3, 4)`` layer, feeds a
    ``(1, 5, 3)`` random batch through the module in eval mode, and checks
    that the result has shape ``(1, 5, 4)``.
    """
    # Create the weight before the input so the RNG draw order is unchanged.
    proj_weight = torch.nn.Linear(3, 4).cuda().weight
    batch = torch.rand(1, 5, 3).cuda()

    module = MEVO(proj_weight)
    module.eval()

    # No target tensor is supplied; only the output shape is checked here.
    result = module(batch, None)
    expected_shape = (1, 5, 4)
    assert result.shape == expected_shape
@skip_if_no_cuda @skip_if_no_cuda
def test_mevo(): def test_mevo():
"""Test the MEVO kernel by itself.""" """Test the MEVO kernel by itself."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment