Unverified Commit 89e1ae5f authored by Min Xu, committed by GitHub

[feat] add CosFace paper's LMCL to MEVO (#916)



* [feat] add CosFace paper's LMCL to MEVO

- added baseline algorithm to the reference kernel
- added MEVO version of LMCL
- added unit test to verify it is correct with respect to the reference as well as its memory usage

* updated changelog
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 8ba649e1
...@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.4.6] - TBD
### Added
- CosFace's LMCL is added to MEVO. This is a loss function suited to a large
  number of prediction target classes. It adds normalization, class-separation
  margins, and feature-vector scaling to the standard output projection +
  cross-entropy loss. MEVO supports this with its memory-saving techniques so
  that peak GPU memory is greatly reduced. [#916]
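  For reference, the LMCL transform described above amounts to the following
  (a minimal sketch in plain PyTorch; the function name and values here are
  illustrative, not part of the library API):

  ```python
  import torch
  import torch.nn.functional as F

  def lmcl_logits(features, weight, target, margin=0.35, scale=8.0):
      # Cosine similarity between L2-normalized features and class weights.
      x = F.normalize(features, dim=1)   # (tokens, d_model)
      w = F.normalize(weight, dim=1)     # (vocab, d_model)
      logits = x @ w.T                   # (tokens, vocab)
      # Subtract the margin only from each token's target-class logit.
      logits[torch.arange(x.shape[0]), target] -= margin
      # Scale so the softmax is sufficiently peaked despite |cos| <= 1.
      return scale * logits
  ```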
## [0.4.5] - 2022-01-14
...
...@@ -68,6 +68,10 @@ def get_data(
class BaselineSoftmax(nn.Module):
"""Baseline softmax that does an output linear projection and a softmax.
We also support LMCL (Large Margin Cosine Loss) from the CosFace paper. See
more detailed comment in the MEVO class below.
This is intended to be used with an embedding layer with shared weights.
Args:
...@@ -77,10 +81,23 @@ class BaselineSoftmax(nn.Module):
Unused. It is here to make kernel init easier with MEVO.
log_softmax (bool):
If True, use log_softmax instead of softmax.
margin (float):
Used in LMCL (when scale != None). See MEVO comments for
more details.
scale (Optional[float]):
Used in LMCL. If scale is None, LMCL is turned off. See
MEVO comments for more details.
""" """
def __init__(self, proj_weight: nn.Parameter, tile_factor: int = 0, log_softmax: bool = True): def __init__(
self,
proj_weight: nn.Parameter,
tile_factor: int = 0,
log_softmax: bool = True,
margin: float = 0.35,
scale: Optional[float] = None,
):
super().__init__()
out_dim, in_dim = proj_weight.shape
assert "cuda" in str(proj_weight.device), "weight should be on GPU"
...@@ -92,6 +109,27 @@ class BaselineSoftmax(nn.Module):
assert self.fc.weight.dtype in [torch.float16, torch.float32], self.fc.weight.dtype
self.fp16 = self.fc.weight.dtype == torch.float16
self.log_softmax = log_softmax
self.margin = margin
self.scale = scale
def lmcl_pre_softmax(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# normalize feature and fc layer before multiplication
# n: number of features (tokens)
# k: number of classes (vocab size)
# c: hidden dimension (d_model)
x = F.normalize(input, dim=1)
w = F.normalize(self.fc.weight, dim=1)
logits = torch.einsum("nc,kc->nk", x, w)
# add margin
row_ind = torch.arange(x.shape[0], dtype=torch.long).to(x.device)
col_ind = target
logits[row_ind, col_ind] -= self.margin
# add scale
logits *= self.scale
return logits
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # type: ignore
"""Forward function that computes softmax output with the input and target."""
...@@ -100,6 +138,9 @@ class BaselineSoftmax(nn.Module):
input, target = _reshape_inputs(input, target)
if self.fp16:
assert input.dtype == torch.float16
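# LMCL path (scale is set): normalize, subtract the target-class margin, then scale the logits.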
if self.scale is not None:
x = self.lmcl_pre_softmax(input, target)
else:
x = self.fc(input)
# Note that we do softmax in FP32, which is important for numerical stability.
if self.log_softmax:
...@@ -119,8 +160,15 @@ class BaselineSoftmaxNllLoss(BaselineSoftmax):
This class is used for testing and benchmarking.
"""
def __init__(
self,
proj_weight: nn.Parameter,
tile_factor: int = 0,
log_softmax: bool = True,
margin: float = 0.35,
scale: Optional[float] = None,
):
super().__init__(proj_weight, tile_factor, log_softmax, margin, scale)
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # type: ignore
"""Forward that directly computes the loss."""
...@@ -131,17 +179,48 @@ class BaselineSoftmaxNllLoss(BaselineSoftmax):
return F.nll_loss(x, target, reduction="sum")
def lmcl_matmul(
i: torch.Tensor, w: torch.Tensor, tgt: torch.Tensor, w_idx: int, margin: float, scale: Optional[float]
) -> torch.Tensor:
"""LMCL variation of matmul with normalization, margin and scale."""
# normalize and matmul
logits = torch.matmul(F.normalize(i, dim=1), F.normalize(w, dim=1).T)
# add margin using a mask since tgt might be out of the weight split's range.
mask = torch.arange(w_idx * w.shape[0], (w_idx + 1) * w.shape[0], dtype=torch.long, device=i.device).expand(
i.shape[0], -1
)
logits[mask == tgt.reshape(-1, 1)] -= margin
# add scale
logits *= scale
return logits
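A hypothetical helper (a sketch only; not referenced elsewhere in this module) illustrating
why the mask above works: with equal-size vocab splits, concatenating ``lmcl_matmul`` over
the splits reproduces the LMCL logits of the unsplit weight, because only the split that
contains a token's target class subtracts the margin.
def _lmcl_matmul_split_sanity_check() -> None:
    # Sketch: compare the tiled computation against the unsplit LMCL logits on CPU data.
    tokens, d_model, vocab, margin, scale = 3, 4, 8, 0.35, 8.0
    inp = torch.randn(tokens, d_model)
    weight = torch.randn(vocab, d_model)
    tgt = torch.randint(0, vocab, (tokens,))
    full = torch.matmul(F.normalize(inp, dim=1), F.normalize(weight, dim=1).T)
    full[torch.arange(tokens), tgt] -= margin
    full *= scale
    tiled = torch.cat(
        [lmcl_matmul(inp, w, tgt, idx, margin, scale) for idx, w in enumerate(torch.split(weight, vocab // 2))],
        dim=1,
    )
    assert torch.allclose(full, tiled, atol=1e-5)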
class GetMaxFunction(torch.autograd.Function):
"""Custom checkpointed function to get max-per-token from an input and a weight"""
@staticmethod
def get_max(
i: torch.Tensor,
w: torch.Tensor,
tgt: torch.Tensor,
w_idx: int,
full_precision: bool,
margin: float,
scale: Optional[float],
) -> torch.Tensor:
""" """
Throughout this code: Throughout this code:
i: input data with shape = (split-of-tokens, d_model) i: input data with shape = (split-of-tokens, d_model)
w: weight data with shape = (split-of-vocabs, d_model) w: weight data with shape = (split-of-vocabs, d_model)
tgt: target prediction data with shape = (split-of-tokens,)
""" """
if scale is not None:
_m = lmcl_matmul(i, w, tgt, w_idx, margin, scale)
else:
_m = torch.matmul(i, w.T)
if full_precision:
_m = _m.float()
...@@ -153,6 +232,7 @@ class GetMaxFunction(torch.autograd.Function):
ctx: Any,
i: torch.Tensor,
w: torch.Tensor,
tgt: torch.Tensor,
kernel_obj: "MemoryEfficientVocabOutput",
w_idx: int,
w_split_size: int,
...@@ -161,7 +241,7 @@ class GetMaxFunction(torch.autograd.Function):
"""Forward function that computes the max, without saving activations."""
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print("DEBUG max fwd")
ctx.save_for_backward(i, w, tgt)
ctx.kernel_obj = kernel_obj
ctx.w_idx = w_idx
ctx.w_split_size = w_split_size
...@@ -171,7 +251,7 @@ class GetMaxFunction(torch.autograd.Function):
# The activations will be recomputed in backward below and freed
# immediately after use. This saves the overall GPU peak memory of this layer.
with torch.no_grad():
return GetMaxFunction.get_max(i, w, tgt, w_idx, kernel_obj.fp_max, kernel_obj.margin, kernel_obj.scale)
@staticmethod
def backward(ctx: Any, *args: Any) -> Any:
...@@ -186,7 +266,7 @@ class GetMaxFunction(torch.autograd.Function):
assert ctx.kernel_obj.proj_weight.grad is not None
# Get saved i, w, and tgt.
i, w, tgt = ctx.saved_tensors
assert i.requires_grad
assert w.requires_grad
# We use ``detach()'' to ensure the backward call below does not
...@@ -199,7 +279,9 @@ class GetMaxFunction(torch.autograd.Function):
# Forward + backward again.
with torch.enable_grad():
# This saves the activations.
maxs = GetMaxFunction.get_max(
i, w, tgt, ctx.w_idx, ctx.kernel_obj.fp_max, ctx.kernel_obj.margin, ctx.kernel_obj.scale
)
# This will use the activations and free them immediately.
torch.autograd.backward(maxs, *args)
...@@ -208,14 +290,26 @@ class GetSumFunction(torch.autograd.Function):
with torch.no_grad():
grads = torch.split(ctx.kernel_obj.proj_weight.grad, ctx.w_split_size)
grads[ctx.w_idx].add_(w.grad)
return i.grad, None, None, None, None, None, None
class GetSumFunction(torch.autograd.Function):
"""Custom checkpointed function to get sum-per-token from an input and a weight."""
@staticmethod
def get_sum(
i: torch.Tensor,
w: torch.Tensor,
tgt: torch.Tensor,
maxs: torch.Tensor,
w_idx: int,
full_precision: bool,
margin: float,
scale: Optional[float],
) -> torch.Tensor:
if scale is not None:
_s = lmcl_matmul(i, w, tgt, w_idx, margin, scale)
else:
_s = torch.matmul(i, w.T)
if full_precision:
_s = _s.float()
...@@ -227,6 +321,7 @@ class GetSumFunction(torch.autograd.Function):
ctx: Any,
i: torch.Tensor,
w: torch.Tensor,
tgt: torch.Tensor,
maxs: torch.Tensor,
kernel_obj: "MemoryEfficientVocabOutput",
w_idx: int,
...@@ -236,13 +331,15 @@ class GetSumFunction(torch.autograd.Function):
"""Forward function that computes the sum, without saving activations."""
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print("DEBUG sum fwd")
ctx.save_for_backward(i, w, tgt, maxs)
ctx.kernel_obj = kernel_obj
ctx.w_idx = w_idx
ctx.w_split_size = w_split_size
assert split_dim == 0
with torch.no_grad():
return GetSumFunction.get_sum(
i, w, tgt, maxs, w_idx, kernel_obj.fp_sum, kernel_obj.margin, kernel_obj.scale
)
@staticmethod
def backward(ctx: Any, *args: Any) -> Any:
...@@ -257,7 +354,7 @@ class GetSumFunction(torch.autograd.Function):
assert ctx.kernel_obj.proj_weight.grad is not None
# Get saved i, w, tgt, and maxs.
i, w, tgt, maxs = ctx.saved_tensors
assert i.requires_grad
assert w.requires_grad
assert maxs.requires_grad
...@@ -267,7 +364,9 @@ class GetSumFunction(torch.autograd.Function):
# Forward + backward again.
with torch.enable_grad():
sums = GetSumFunction.get_sum(
i, w, tgt, maxs, ctx.w_idx, ctx.kernel_obj.fp_sum, ctx.kernel_obj.margin, ctx.kernel_obj.scale
)
torch.autograd.backward(sums, *args)
# Accumulate the grads.
...@@ -275,22 +374,35 @@ class GetSumFunction(torch.autograd.Function):
with torch.no_grad():
grads = torch.split(ctx.kernel_obj.proj_weight.grad, ctx.w_split_size)
grads[ctx.w_idx].add_(w.grad)
return i.grad, None, None, maxs.grad, None, None, None, None
class TargetScoreFunction(torch.autograd.Function):
"""Custom checkpointed function to compute the target score."""
@staticmethod
def get_target_score(
i: torch.Tensor,
w: torch.Tensor,
target: torch.Tensor,
full_precision: bool,
margin: float,
scale: Optional[float],
) -> torch.Tensor:
tokens, d_model = i.shape
assert d_model == w.shape[1]
tw = w.gather(dim=0, index=target.reshape(target.shape[0], 1).expand(target.shape[0], d_model))
assert tw.shape == (tokens, d_model)
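# For LMCL, the target score is scale * (cos(i, tw) - margin): normalize both
# vectors here; the margin and scale are applied after the sum below.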
if scale is not None:
target_score = F.normalize(i, dim=1) * F.normalize(tw, dim=1)
else:
target_score = i * tw target_score = i * tw
if full_precision: if full_precision:
target_score = target_score.float() target_score = target_score.float()
target_score = target_score.sum(dim=1) # sum into target scores with shape (tokens,) target_score = target_score.sum(dim=1) # sum into target scores with shape (tokens,)
if scale is not None:
target_score -= margin
target_score *= scale
return target_score
@staticmethod
...@@ -303,7 +415,9 @@ class TargetScoreFunction(torch.autograd.Function):
ctx.save_for_backward(i, w, target)
ctx.kernel_obj = kernel_obj
with torch.no_grad():
x = TargetScoreFunction.get_target_score(
i, w, target, kernel_obj.fp_target, kernel_obj.margin, kernel_obj.scale
)
return x
@staticmethod
...@@ -319,7 +433,9 @@ class TargetScoreFunction(torch.autograd.Function):
i = i.detach().requires_grad_(True)
w = w.detach().requires_grad_(True)
with torch.enable_grad():
scores = TargetScoreFunction.get_target_score(
i, w, target, ctx.kernel_obj.fp_target, ctx.kernel_obj.margin, ctx.kernel_obj.scale
)
torch.autograd.backward(scores, *args)
if ctx.kernel_obj.proj_weight.grad is not None:
# This means we accumulate full grad between iters. Not memory efficient.
...@@ -388,18 +504,55 @@ class BackwardTrigger(nn.Module):
class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
"""Fused fc + softmax + nll_loss in a tiled fashion.
MEVO uses much less memory but is quite a bit slower.
MEVO also implements the LMCL (Large Margin Cosine Loss) function introduced by
the highly cited
`CosFace: Large Margin Cosine Loss for Deep Face Recognition [Wang et al.]`_.

.. _`CosFace: Large Margin Cosine Loss for Deep Face Recognition [Wang et al.]`: https://arxiv.org/abs/1801.09414

LMCL can be turned on using the ``margin`` and ``scale`` parameters below. These
hyperparameters most likely require tuning, depending on the number of classes, etc.
MEVO's LMCL can be suitable for face recognition and image retrieval tasks, especially
when the number of prediction target classes is large. MEVO is slower but can use much
less GPU memory in that case, which enables training with larger batches. We
hope this is helpful, but we strongly recommend that users (AI researchers
and engineers) carefully consider their applications of this technology. This
type of technology should not be used exclusively by a small group of people
to potentially harm the general public.
Args:
proj_weight (nn.Parameter):
Sharing this weight with an embedding layer.
tile_factor (int):
Number of splits to use on the input sequence and vocab dimensions.
Default: 16
reduction (str):
Reduction OP (sum or mean).
Default: sum
margin (float):
Hyperparameter of the separation margin between classes. See the
appendix of the CosFace paper for a formula on how to compute its
value properly. The default value is unlikely to be suitable in all
cases.
Default: 0.35
scale (Optional[float]):
Hyperparameter of the feature-vector-scaling for LMCL. When not
supplied, LMCL is turned off. See the appendix of the CosFace paper for
a formula on how to compute its value properly.
Default: None
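Example (a minimal usage sketch; the embedding size, dtype and scale value below
are illustrative only and the LMCL hyperparameters will likely need tuning)::

    import torch
    import torch.nn as nn

    emb = nn.Embedding(32768, 1024).cuda().half()
    mevo = MemoryEfficientVocabOutput(emb.weight, margin=0.35, scale=16.0)
    hidden = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
    target = torch.randint(0, 32768, (2, 128), device="cuda")
    loss = mevo(hidden, target)  # fused projection (+ LMCL) + softmax + NLL, "sum" reduction
    loss.backward()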
""" """
def __init__(self, proj_weight: nn.Parameter, tile_factor: int = 16, reduction: str = "sum"): def __init__(
self,
proj_weight: nn.Parameter,
tile_factor: int = 16,
reduction: str = "sum",
margin: float = 0.35,
scale: Optional[float] = None,
):
super().__init__()
self.proj_weight = proj_weight
# TODO (Min): these two factors don't have to be the same. More tuning can be done.
...@@ -410,12 +563,14 @@ class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
self.log_softmax = True
self.reduction = reduction
assert self.reduction in ["sum", "mean"]
self.margin = margin
self.scale = scale
self.trigger = BackwardTrigger(self.proj_weight)
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print(
f"DEBUG cfg tf_in={self.tf_in} tf_w={self.tf_w} fp_max={self.fp_max} "
f"fp_sum={self.fp_sum} fp_target={self.fp_target} log_softmax={self.log_softmax} "
f"reduction={self.reduction} margin={self.margin} scale={self.scale}"
)
def get_target_nlprob(
...@@ -432,6 +587,8 @@ class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
def eval_forward(self, input: torch.Tensor) -> torch.Tensor:
"""Eval time forward that doesn't fuse the softmax and NLL Loss kernels."""
# The margin, scaling and normalization of LMCL do not apply at eval time as far as
# I can tell. Therefore, we just do a matmul like the standard output layer.
return torch.matmul(input, self.proj_weight.T)
def forward(self, input: torch.Tensor, target: Optional[torch.Tensor]) -> torch.Tensor: # type: ignore
...@@ -449,8 +606,10 @@ class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
input, target = _reshape_inputs(input, target)
tokens, d_model = input.shape
(t2,) = target.shape
vocab, d2 = self.proj_weight.shape
assert d_model == d2, f"incorrect shape {d_model} vs {d2}"
assert tokens == t2, f"incorrect shape {tokens} vs {t2}"
split_dim = 0
input_split_size = _next_power_of_2_or_max(tokens // self.tf_in, tokens)
weight_split_size = _next_power_of_2_or_max(vocab // self.tf_w, vocab)
...@@ -458,12 +617,16 @@ class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
weight = self.trigger()
weights = torch.split(weight, weight_split_size, split_dim)
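# When LMCL is off, the per-tile targets are unused; empty placeholder tensors keep
# the autograd Function call signatures identical in both modes.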
targets = tuple([torch.Tensor()] * len(inputs))
if self.scale is not None:
targets = torch.split(target, input_split_size, split_dim)
# Get maxs
maxs = []
for i, tgt in zip(inputs, targets):
m = None # max with (tokens_tile,) shape
for w_idx, w in enumerate(weights):
_m = GetMaxFunction.apply(i, w, tgt, self, w_idx, weight_split_size, split_dim)
if m is None:
m = _m
else:
...@@ -475,10 +638,10 @@ class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
# Get sums.
sums = []
for i, tgt, debase_max in zip(inputs, targets, maxs):
s = None # sum with (tokens_tile,) shape
for w_idx, w in enumerate(weights):
_s = GetSumFunction.apply(i, w, tgt, debase_max, self, w_idx, weight_split_size, split_dim)
if s is None:
s = _s
else:
...
...@@ -38,10 +38,15 @@ def test_mevo_eval():
assert out.shape == (1, 5, 4)
# Note: for lmcl_scale, an overly large value, like 64 for small-shape inputs,
# will cause inf/nan in MEVO. Larger scale values are only needed for large-shape inputs.
@skip_if_no_cuda
@pytest.mark.parametrize("lmcl_scale", [None, 8])
def test_mevo(lmcl_scale):
"""Test the MEVO kernel in a single process (no DDP/FSDP)."""
# Set seed and reset peak mem so that peak measure below is correct.
torch.random.manual_seed(os.getpid())
torch.cuda.reset_peak_memory_stats()
shape = ((5, 3), (3, 7))
# Turn on large data for local testing.
large = False
...@@ -50,11 +55,11 @@ def test_mevo():
print("\nshapes are", shape)
input, weight, target = get_data(shape, dtype=torch.float16)
k = MEVO(weight, tile_factor=16, scale=lmcl_scale)
o = k(input, target)
o.backward()
print("MEVO loss", o, o.shape)
del o
cur_mem = round(torch.cuda.memory_allocated() / 1024 / 1024)
...@@ -70,22 +75,22 @@ def test_mevo():
mem = round(torch.cuda.max_memory_allocated() / 1024 / 1024)
print("after moving input and its grad, cur and peak mem for tiled fwd+bwd =", cur_mem, mem)
print("MEVO grad norm and grad", weight.grad.norm(), weight.grad)
g1 = weight.grad.clone()
weight.grad = None
input = input_data.cuda().requires_grad_(True)
refk = BaselineSoftmaxNllLoss(weight, scale=lmcl_scale)
o = refk(input, target)
o.backward()
print("Reference loss", o, o.shape)
del o
print("Reference grad norm and grad", weight.grad.norm(), weight.grad)
g2 = weight.grad.clone()
input_grad2 = input.grad.cpu()
# Print the diff. We use .cuda() since in torch 1.7 and 1.8, min() and max() are not
# implemented for cpu float16. The diff should in general be below 0.01 in magnitude.
diff = g1 - g2
print("weight grad diff", diff.cuda().min(), diff.cuda().max())
diff = input_grad1 - input_grad2
...