Unverified Commit 89e1ae5f authored by Min Xu, committed by GitHub

[feat] add CosFace paper's LMCL to MEVO (#916)



* [feat] add CosFace paper's LMCL to MEVO

- added the baseline algorithm to the reference kernel
- added the MEVO version of LMCL
- added a unit test verifying correctness against the reference as well as its memory usage

* updated changelog
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 8ba649e1
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.4.6] - TBD
### Added
- CosFace's LMCL is added to MEVO. This is a loss function that is suitable
for a large number of prediction target classes. It adds normalization,
class separation margins and feature vector scaling to the standard
output projection + cross-entropy loss. MEVO supports this with its
memory-saving techniques so that peak GPU memory is much reduced. [#916]
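For reference, a sketch of the loss this adds, following the CosFace paper's formulation (here $s$ is the scale, $m$ the margin, and $\theta_j$ the angle between a token's feature vector and the class-$j$ weight row):

$$
L_{lmc} = -\frac{1}{N}\sum_{i}\log
    \frac{e^{s(\cos\theta_{y_i} - m)}}
         {e^{s(\cos\theta_{y_i} - m)} + \sum_{j \ne y_i} e^{s\cos\theta_j}},
\qquad
\cos\theta_j = \frac{W_j \cdot x_i}{\lVert W_j\rVert\,\lVert x_i\rVert}.
$$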
## [0.4.5] - 2022-01-14
@@ -68,6 +68,10 @@ def get_data(
class BaselineSoftmax(nn.Module):
"""Baseline softmax that does an output linear projection and a softmax.
We also support LMCL (Large Margin Cosine Loss) from the CosFace paper. See
the more detailed comment in the MEVO class below.
This is intended to be used with an embedding layer with shared weights.
Args:
@@ -77,10 +81,23 @@ class BaselineSoftmax(nn.Module):
Unused. It is here to make kernel init easier with MEVO.
log_softmax (bool):
If True, use log_softmax instead of softmax.
margin (float):
Used in LMCL (when scale != None). See MEVO comments for
more details.
scale (Optional[float]):
Used in LMCL. If scale is None, LMCL is turned off. See
MEVO comments for more details.
"""
def __init__(self, proj_weight: nn.Parameter, tile_factor: int = 0, log_softmax: bool = True):
def __init__(
self,
proj_weight: nn.Parameter,
tile_factor: int = 0,
log_softmax: bool = True,
margin: float = 0.35,
scale: Optional[float] = None,
):
super().__init__()
out_dim, in_dim = proj_weight.shape
assert "cuda" in str(proj_weight.device), "weight should be on GPU"
@@ -92,6 +109,27 @@ class BaselineSoftmax(nn.Module):
assert self.fc.weight.dtype in [torch.float16, torch.float32], self.fc.weight.dtype
self.fp16 = self.fc.weight.dtype == torch.float16
self.log_softmax = log_softmax
self.margin = margin
self.scale = scale
def lmcl_pre_softmax(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# normalize feature and fc layer before multiplication
# n: number of features (tokens)
# k: number of classes (vocab size)
# c: hidden dimension (d_model)
x = F.normalize(input, dim=1)
w = F.normalize(self.fc.weight, dim=1)
logits = torch.einsum("nc,kc->nk", x, w)
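# logits[n, k] is now the cosine similarity between token n and class k, in [-1, 1]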
# add margin
row_ind = torch.arange(x.shape[0], dtype=torch.long).to(x.device)
col_ind = target
logits[row_ind, col_ind] -= self.margin
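# i.e. only each token's target-class logit has the margin subtracted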
# add scale
logits *= self.scale
return logits
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # type: ignore
"""Forward function that computes softmax output with the input and target."""
@@ -100,6 +138,9 @@ class BaselineSoftmax(nn.Module):
input, target = _reshape_inputs(input, target)
if self.fp16:
assert input.dtype == torch.float16
if self.scale is not None:
x = self.lmcl_pre_softmax(input, target)
else:
x = self.fc(input)
# Note that we do softmax in FP32, which is important for numerical stability.
if self.log_softmax:
@@ -119,8 +160,15 @@ class BaselineSoftmaxNllLoss(BaselineSoftmax):
This class is used for testing and benchmarking.
"""
def __init__(self, proj_weight: nn.Parameter, tile_factor: int = 0, log_softmax: bool = True):
super().__init__(proj_weight, tile_factor, log_softmax)
def __init__(
self,
proj_weight: nn.Parameter,
tile_factor: int = 0,
log_softmax: bool = True,
margin: float = 0.35,
scale: Optional[float] = None,
):
super().__init__(proj_weight, tile_factor, log_softmax, margin, scale)
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # type: ignore
"""Forward that directly compute the loss."""
@@ -131,17 +179,48 @@ class BaselineSoftmaxNllLoss(BaselineSoftmax):
return F.nll_loss(x, target, reduction="sum")
def lmcl_matmul(
i: torch.Tensor, w: torch.Tensor, tgt: torch.Tensor, w_idx: int, margin: float, scale: Optional[float]
) -> torch.Tensor:
"""LMCL variation of matmul with normalization, margin and scale."""
# normalize and matmul
logits = torch.matmul(F.normalize(i, dim=1), F.normalize(w, dim=1).T)
# add margin using a mask since tgt might be out of the weight split's range.
mask = torch.arange(w_idx * w.shape[0], (w_idx + 1) * w.shape[0], dtype=torch.long, device=i.device).expand(
i.shape[0], -1
)
logits[mask == tgt.reshape(-1, 1)] -= margin
# add scale
logits *= scale
return logits
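A minimal sketch of what ``lmcl_matmul`` computes for one weight split (hypothetical
shapes; it assumes the ``lmcl_matmul`` defined above is in scope):

import torch

# 4 tokens with d_model=8; a vocab of 10 rows split into two chunks of 5.
i = torch.randn(4, 8)
w0, w1 = torch.split(torch.randn(10, 8), 5)
tgt = torch.tensor([1, 6, 9, 3])  # global class ids in [0, 10)

# For w_idx=1 the mask covers global ids 5..9, so only the tokens whose target
# falls in this split (tokens 1 and 2 here) get the margin subtracted.
logits = lmcl_matmul(i, w1, tgt, w_idx=1, margin=0.35, scale=8.0)
print(logits.shape)  # torch.Size([4, 5]) -- tokens x classes-in-this-split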
class GetMaxFunction(torch.autograd.Function):
"""Custom checkpointed function to get max-per-token from an input and a weight"""
@staticmethod
def get_max(i: torch.Tensor, w: torch.Tensor, full_precision: bool) -> torch.Tensor:
def get_max(
i: torch.Tensor,
w: torch.Tensor,
tgt: torch.Tensor,
w_idx: int,
full_precision: bool,
margin: float,
scale: Optional[float],
) -> torch.Tensor:
"""
Throughout this code:
i: input data with shape = (split-of-tokens, d_model)
w: weight data with shape = (split-of-vocabs, d_model)
tgt: target prediction data with shape = (split-of-tokens,)
"""
if scale is not None:
_m = lmcl_matmul(i, w, tgt, w_idx, margin, scale)
else:
_m = torch.matmul(i, w.T)
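# in either branch, _m has shape (split-of-tokens, split-of-vocabs) for this weight split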
if full_precision:
_m = _m.float()
@@ -153,6 +232,7 @@ class GetMaxFunction(torch.autograd.Function):
ctx: Any,
i: torch.Tensor,
w: torch.Tensor,
tgt: torch.Tensor,
kernel_obj: "MemoryEfficientVocabOutput",
w_idx: int,
w_split_size: int,
@@ -161,7 +241,7 @@
"""Forward function that computes the max, without saving activations."""
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print("DEBUG max fwd")
ctx.save_for_backward(i, w)
ctx.save_for_backward(i, w, tgt)
ctx.kernel_obj = kernel_obj
ctx.w_idx = w_idx
ctx.w_split_size = w_split_size
@@ -171,7 +251,7 @@ class GetMaxFunction(torch.autograd.Function):
# The activations will be recomputed in backward below and freed
# immediately after use. This saves the overall GPU peak memory of this layer.
with torch.no_grad():
return GetMaxFunction.get_max(i, w, kernel_obj.fp_max)
return GetMaxFunction.get_max(i, w, tgt, w_idx, kernel_obj.fp_max, kernel_obj.margin, kernel_obj.scale)
@staticmethod
def backward(ctx: Any, *args: Any) -> Any:
@@ -186,7 +266,7 @@ class GetMaxFunction(torch.autograd.Function):
assert ctx.kernel_obj.proj_weight.grad is not None
# Get saved i and w.
i, w = ctx.saved_tensors
i, w, tgt = ctx.saved_tensors
assert i.requires_grad
assert w.requires_grad
# We use ``detach()'' to ensure the backward call below does not
@@ -199,7 +279,9 @@ class GetMaxFunction(torch.autograd.Function):
# Forward + backward again.
with torch.enable_grad():
# This saves the activations.
maxs = GetMaxFunction.get_max(i, w, ctx.kernel_obj.fp_max)
maxs = GetMaxFunction.get_max(
i, w, tgt, ctx.w_idx, ctx.kernel_obj.fp_max, ctx.kernel_obj.margin, ctx.kernel_obj.scale
)
# This will use the activations and free them immediately.
torch.autograd.backward(maxs, *args)
@@ -208,14 +290,26 @@ class GetMaxFunction(torch.autograd.Function):
with torch.no_grad():
grads = torch.split(ctx.kernel_obj.proj_weight.grad, ctx.w_split_size)
grads[ctx.w_idx].add_(w.grad)
return i.grad, None, None, None, None, None
return i.grad, None, None, None, None, None, None
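The recompute-in-backward pattern used by GetMaxFunction and GetSumFunction, in a
stripped-down generic form (an illustrative sketch, not the actual kernel):

import torch

class CheckpointedMatmul(torch.autograd.Function):
    """Minimal recompute-in-backward example: forward saves only the inputs;
    backward redoes the matmul under enable_grad and backprops through it,
    so the forward activations are never kept alive between the two passes."""

    @staticmethod
    def forward(ctx, i, w):
        ctx.save_for_backward(i, w)
        with torch.no_grad():  # do not keep activations
            return torch.matmul(i, w.T)

    @staticmethod
    def backward(ctx, grad_out):
        i, w = ctx.saved_tensors
        i = i.detach().requires_grad_(True)
        w = w.detach().requires_grad_(True)
        with torch.enable_grad():  # recompute, this time tracking grads
            out = torch.matmul(i, w.T)
        torch.autograd.backward(out, grad_out)
        return i.grad, w.grad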
class GetSumFunction(torch.autograd.Function):
"""Custom checkpointed function to get sum-per-token from an input and a weight."""
@staticmethod
def get_sum(i: torch.Tensor, w: torch.Tensor, maxs: torch.Tensor, full_precision: bool) -> torch.Tensor:
def get_sum(
i: torch.Tensor,
w: torch.Tensor,
tgt: torch.Tensor,
maxs: torch.Tensor,
w_idx: int,
full_precision: bool,
margin: float,
scale: Optional[float],
) -> torch.Tensor:
if scale is not None:
_s = lmcl_matmul(i, w, tgt, w_idx, margin, scale)
else:
_s = torch.matmul(i, w.T)
if full_precision:
_s = _s.float()
@@ -227,6 +321,7 @@ class GetSumFunction(torch.autograd.Function):
ctx: Any,
i: torch.Tensor,
w: torch.Tensor,
tgt: torch.Tensor,
maxs: torch.Tensor,
kernel_obj: "MemoryEfficientVocabOutput",
w_idx: int,
@@ -236,13 +331,15 @@ class GetSumFunction(torch.autograd.Function):
"""Forward function that computes the sum, without saving activations."""
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print("DEBUG sum fwd")
ctx.save_for_backward(i, w, maxs)
ctx.save_for_backward(i, w, tgt, maxs)
ctx.kernel_obj = kernel_obj
ctx.w_idx = w_idx
ctx.w_split_size = w_split_size
assert split_dim == 0
with torch.no_grad():
return GetSumFunction.get_sum(i, w, maxs, kernel_obj.fp_sum)
return GetSumFunction.get_sum(
i, w, tgt, maxs, w_idx, kernel_obj.fp_sum, kernel_obj.margin, kernel_obj.scale
)
@staticmethod
def backward(ctx: Any, *args: Any) -> Any:
@@ -257,7 +354,7 @@ class GetSumFunction(torch.autograd.Function):
assert ctx.kernel_obj.proj_weight.grad is not None
# Get saved i, w, and maxs.
i, w, maxs = ctx.saved_tensors
i, w, tgt, maxs = ctx.saved_tensors
assert i.requires_grad
assert w.requires_grad
assert maxs.requires_grad
@@ -267,7 +364,9 @@ class GetSumFunction(torch.autograd.Function):
# Forward + backward again.
with torch.enable_grad():
sums = GetSumFunction.get_sum(i, w, maxs, ctx.kernel_obj.fp_sum)
sums = GetSumFunction.get_sum(
i, w, tgt, maxs, ctx.w_idx, ctx.kernel_obj.fp_sum, ctx.kernel_obj.margin, ctx.kernel_obj.scale
)
torch.autograd.backward(sums, *args)
# Accumulate the grads.
@@ -275,22 +374,35 @@ class GetSumFunction(torch.autograd.Function):
with torch.no_grad():
grads = torch.split(ctx.kernel_obj.proj_weight.grad, ctx.w_split_size)
grads[ctx.w_idx].add_(w.grad)
return i.grad, None, maxs.grad, None, None, None, None
return i.grad, None, None, maxs.grad, None, None, None, None
class TargetScoreFunction(torch.autograd.Function):
"""Custom checkpointed function to compute the target score."""
@staticmethod
def get_target_score(i: torch.Tensor, w: torch.Tensor, target: torch.Tensor, full_precision: bool) -> torch.Tensor:
def get_target_score(
i: torch.Tensor,
w: torch.Tensor,
target: torch.Tensor,
full_precision: bool,
margin: float,
scale: Optional[float],
) -> torch.Tensor:
tokens, d_model = i.shape
assert d_model == w.shape[1]
tw = w.gather(dim=0, index=target.reshape(target.shape[0], 1).expand(target.shape[0], d_model))
assert tw.shape == (tokens, d_model)
if scale is not None:
target_score = F.normalize(i, dim=1) * F.normalize(tw, dim=1)
else:
target_score = i * tw
if full_precision:
target_score = target_score.float()
target_score = target_score.sum(dim=1) # sum into target scores with shape (tokens,)
if scale is not None:
target_score -= margin
target_score *= scale
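# i.e. target_score[n] == scale * (cos(theta_target) - margin), consistent with the target column produced by lmcl_matmul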
return target_score
@staticmethod
@@ -303,7 +415,9 @@ class TargetScoreFunction(torch.autograd.Function):
ctx.save_for_backward(i, w, target)
ctx.kernel_obj = kernel_obj
with torch.no_grad():
x = TargetScoreFunction.get_target_score(i, w, target, kernel_obj.fp_target)
x = TargetScoreFunction.get_target_score(
i, w, target, kernel_obj.fp_target, kernel_obj.margin, kernel_obj.scale
)
return x
@staticmethod
@@ -319,7 +433,9 @@ class TargetScoreFunction(torch.autograd.Function):
i = i.detach().requires_grad_(True)
w = w.detach().requires_grad_(True)
with torch.enable_grad():
scores = TargetScoreFunction.get_target_score(i, w, target, ctx.kernel_obj.fp_target)
scores = TargetScoreFunction.get_target_score(
i, w, target, ctx.kernel_obj.fp_target, ctx.kernel_obj.margin, ctx.kernel_obj.scale
)
torch.autograd.backward(scores, *args)
if ctx.kernel_obj.proj_weight.grad is not None:
# This means we accumulate full grad between iters. Not memory efficient.
@@ -388,18 +504,55 @@ class BackwardTrigger(nn.Module):
class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
"""Fused fc + softmax + nll_loss in a tiled fashion.
This uses much less memory but is quite a bit slower.
MEVO uses much less memory but is quite a bit slower.
MEVO also implements the LMCL (Large Margin Cosine Loss) function introduced by
the highly cited
`CosFace: Large Margin Cosine Loss for Deep Face Recognition [Wang et al.]`_.
.. _`CosFace: Large Margin Cosine Loss for Deep Face Recognition [Wang et al.]`: https://arxiv.org/abs/1801.09414
LMCL can be turned on using the ``margin`` and ``scale`` parameters below. These
hyperparameters most likely require tuning, depending on the number of classes, etc.
MEVO LMCL can be suitable for face recognition and image retrieval tasks, esp. when
the number of prediction target classes is large. MEVO is slower but can use much
less GPU memory in that case, which enables training with larger batches. We
hope this is helpful, but we strongly recommend that users (AI researchers
and engineers) carefully consider their applications of this technology. This
type of technology should not be used by a small group of people exclusively to
potentially harm the general public.
Args:
proj_weight (nn.Parameter):
The output projection weight, typically shared with an embedding layer.
tile_factor (int):
Number of splits to use on the input sequence and vocab dimensions.
Default: 16
reduction (str):
Reduction OP (sum or mean).
Default: sum
margin (float):
Hyperparameter of the separation margin between classes. See the
appendix of the CosFace paper for a formula on how to compute its
value properly. The default value is unlikely to be suitable in all
cases.
Default: 0.35
scale (Optional[float]):
Hyperparameter of the feature-vector-scaling for LMCL. When not
supplied, LMCL is turned off. See the appendix of the CosFace paper for
a formula on how to compute its value properly.
Default: None
"""
def __init__(self, proj_weight: nn.Parameter, tile_factor: int = 16, reduction: str = "sum"):
def __init__(
self,
proj_weight: nn.Parameter,
tile_factor: int = 16,
reduction: str = "sum",
margin: float = 0.35,
scale: Optional[float] = None,
):
super().__init__()
self.proj_weight = proj_weight
# TODO (Min): these two factors don't have to be the same. More tuning can be done.
@@ -410,12 +563,14 @@ class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
self.log_softmax = True
self.reduction = reduction
assert self.reduction in ["sum", "mean"]
self.margin = margin
self.scale = scale
self.trigger = BackwardTrigger(self.proj_weight)
if DEBUG and dist.is_initialized() and dist.get_rank() == 0:
print(
f"DEBUG cfg tf_in={self.tf_in} tf_w={self.tf_w} fp_max={self.fp_max} "
f"fp_sum={self.fp_sum} fp_target={self.fp_target} log_softmax={self.log_softmax} "
f"reduction={self.reduction}"
f"reduction={self.reduction} margin={self.margin} scale={self.scale}"
)
def get_target_nlprob(
@@ -432,6 +587,8 @@ class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
def eval_forward(self, input: torch.Tensor) -> torch.Tensor:
"""Eval time forward that doesn't fuse the softmax and NLL Loss kernels."""
# The margin, scaling and normalization of LMCL do not apply at eval time as far as
# I can tell. Therefore, we just do a matmul like the standard output layer.
return torch.matmul(input, self.proj_weight.T)
def forward(self, input: torch.Tensor, target: Optional[torch.Tensor]) -> torch.Tensor: # type: ignore
@@ -449,8 +606,10 @@ class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
input, target = _reshape_inputs(input, target)
tokens, d_model = input.shape
(t2,) = target.shape
vocab, d2 = self.proj_weight.shape
assert d_model == d2
assert d_model == d2, f"incorrect shape {d_model} vs {d2}"
assert tokens == t2, f"incorrect shape {tokens} vs {t2}"
split_dim = 0
input_split_size = _next_power_of_2_or_max(tokens // self.tf_in, tokens)
weight_split_size = _next_power_of_2_or_max(vocab // self.tf_w, vocab)
@@ -458,12 +617,16 @@ class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
weight = self.trigger()
weights = torch.split(weight, weight_split_size, split_dim)
targets = tuple([torch.Tensor()] * len(inputs))
if self.scale is not None:
targets = torch.split(target, input_split_size, split_dim)
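# when scale is None, the empty placeholders above are passed through and simply ignored by the kernels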
# Get maxs
maxs = []
for i in inputs:
for i, tgt in zip(inputs, targets):
m = None # max with (tokens_tile,) shape
for w_idx, w in enumerate(weights):
_m = GetMaxFunction.apply(i, w, self, w_idx, weight_split_size, split_dim)
_m = GetMaxFunction.apply(i, w, tgt, self, w_idx, weight_split_size, split_dim)
if m is None:
m = _m
else:
@@ -475,10 +638,10 @@ class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
# Get sums.
sums = []
for idx, i in enumerate(inputs):
for i, tgt, debase_max in zip(inputs, targets, maxs):
s = None # sum with (tokens_tile,) shape
for w_idx, w in enumerate(weights):
_s = GetSumFunction.apply(i, w, maxs[idx], self, w_idx, weight_split_size, split_dim)
_s = GetSumFunction.apply(i, w, tgt, debase_max, self, w_idx, weight_split_size, split_dim)
if s is None:
s = _s
else:
@@ -38,10 +38,15 @@ def test_mevo_eval():
assert out.shape == (1, 5, 4)
# Note: for lmcl_scale, an overly large value, like 64 for a small-shape input,
# will cause inf/nan in MEVO. A larger scale value is only needed for large-shape inputs.
@skip_if_no_cuda
def test_mevo():
"""Test the MEVO kernel by itself."""
@pytest.mark.parametrize("lmcl_scale", [None, 8])
def test_mevo(lmcl_scale):
"""Test the MEVO kernel in a single process (no DDP/FSDP)."""
# Set seed and reset peak mem so that peak measure below is correct.
torch.random.manual_seed(os.getpid())
torch.cuda.reset_peak_memory_stats()
shape = ((5, 3), (3, 7))
# Turn on large data for local testing.
large = False
@@ -50,11 +55,11 @@ def test_mevo():
print("\nshapes are", shape)
input, weight, target = get_data(shape, dtype=torch.float16)
k = MEVO(weight, tile_factor=16)
k = MEVO(weight, tile_factor=16, scale=lmcl_scale)
o = k(input, target)
o.backward()
print(o, o.shape)
print("MEVO loss", o, o.shape)
del o
cur_mem = round(torch.cuda.memory_allocated() / 1024 / 1024)
@@ -70,22 +75,22 @@ def test_mevo():
mem = round(torch.cuda.max_memory_allocated() / 1024 / 1024)
print("after moving input and its grad, cur and peak mem for tiled fwd+bwd =", cur_mem, mem)
print(weight.grad.norm(), weight.grad)
print("MEVO grad norm and grad", weight.grad.norm(), weight.grad)
g1 = weight.grad.clone()
weight.grad = None
input = input_data.cuda().requires_grad_(True)
refk = BaselineSoftmaxNllLoss(weight)
refk = BaselineSoftmaxNllLoss(weight, scale=lmcl_scale)
o = refk(input, target)
o.backward()
print(o, o.shape)
print("Reference loss", o, o.shape)
del o
print(weight.grad.norm(), weight.grad)
print("Reference grad norm and grad", weight.grad.norm(), weight.grad)
g2 = weight.grad.clone()
input_grad2 = input.grad.cpu()
# Print the diff. We use .cuda() since in 1.7 and 1.8, min() and max() are not
# implemented for cpu float16.
# Print the diff. We use .cuda() since in torch 1.7 and 1.8, min() and max() are not
# implemented for cpu float16. The diff should in general be below 0.01 in magnitude.
diff = g1 - g2
print("weight grad diff", diff.cuda().min(), diff.cuda().max())
diff = input_grad1 - input_grad2