Unverified commit cb47293e authored by Shanmugam Ramasamy, committed by GitHub

Patching clip model to create mask tensor on the device (#22711)



* Patching clip model to create mask tensor on the device

* Addressing PR's comments

* Addressing PR's comments

* Addressing PR's comments

---------
Co-authored-by: Shanmugam Ramasamy <shanmugamr@shanmugamr-mlt.client.nvidia.com>
parent 2da73f63
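
For context: previously the causal mask was allocated on the CPU and then moved with .to(hidden_states.device); the patch passes the device straight to torch.empty so the mask is created where it is needed, skipping the extra host allocation and copy. A minimal standalone sketch of the new pattern (the free function build_causal_attention_mask below is illustrative, not the exact transformers code):

import torch

def build_causal_attention_mask(bsz, seq_len, dtype, device=None):
    # Allocate the mask directly on the target device instead of on the CPU.
    mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
    mask.fill_(torch.finfo(dtype).min)  # additive mask: large negative value everywhere
    mask.triu_(1)  # keep the negative values only above the diagonal (future positions); zero the rest
    return mask.unsqueeze(1)  # add a broadcastable head dimension

device = "cuda" if torch.cuda.is_available() else "cpu"
mask = build_causal_attention_mask(2, 8, torch.float32, device=device)
print(mask.shape, mask.device)  # torch.Size([2, 1, 8, 8]) on the chosen device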
@@ -714,8 +714,8 @@ class CLIPTextTransformer(nn.Module):
         bsz, seq_len = input_shape
         # CLIP's text model uses causal mask, prepare it here.
         # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
-        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
-            hidden_states.device
+        causal_attention_mask = self._build_causal_attention_mask(
+            bsz, seq_len, hidden_states.dtype, device=hidden_states.device
         )
         # expand attention_mask
         if attention_mask is not None:
@@ -752,11 +752,11 @@ class CLIPTextTransformer(nn.Module):
             attentions=encoder_outputs.attentions,
         )

-    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
+    def _build_causal_attention_mask(self, bsz, seq_len, dtype, device=None):
         # lazily create causal attention mask, with full attention between the vision tokens
         # pytorch uses additive attention mask; fill with -inf
-        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
-        mask.fill_(torch.tensor(torch.finfo(dtype).min))
+        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
+        mask.fill_(torch.finfo(dtype).min)
         mask.triu_(1)  # zero out the lower diagonal
         mask = mask.unsqueeze(1)  # expand mask
         return mask
...
@@ -726,8 +726,8 @@ class CLIPSegTextTransformer(nn.Module):
         bsz, seq_len = input_shape
         # CLIPSeg's text model uses causal mask, prepare it here.
         # https://github.com/openai/CLIPSeg/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clipseg/model.py#L324
-        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
-            hidden_states.device
+        causal_attention_mask = self._build_causal_attention_mask(
+            bsz, seq_len, hidden_states.dtype, device=hidden_states.device
         )
         # expand attention_mask
         if attention_mask is not None:
@@ -764,11 +764,11 @@ class CLIPSegTextTransformer(nn.Module):
             attentions=encoder_outputs.attentions,
         )

-    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
+    def _build_causal_attention_mask(self, bsz, seq_len, dtype, device=None):
         # lazily create causal attention mask, with full attention between the vision tokens
         # pytorch uses additive attention mask; fill with -inf
-        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
-        mask.fill_(torch.tensor(torch.finfo(dtype).min))
+        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
+        mask.fill_(torch.finfo(dtype).min)
         mask.triu_(1)  # zero out the lower diagonal
         mask = mask.unsqueeze(1)  # expand mask
         return mask
...
@@ -1108,8 +1108,8 @@ class GroupViTTextTransformer(nn.Module):
         bsz, seq_len = input_shape
         # CLIP's text model uses causal mask, prepare it here.
         # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
-        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
-            hidden_states.device
+        causal_attention_mask = self._build_causal_attention_mask(
+            bsz, seq_len, hidden_states.dtype, device=hidden_states.device
         )
         # expand attention_mask
         if attention_mask is not None:
@@ -1146,11 +1146,11 @@ class GroupViTTextTransformer(nn.Module):
             attentions=encoder_outputs.attentions,
         )

-    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
+    def _build_causal_attention_mask(self, bsz, seq_len, dtype, device=None):
         # lazily create causal attention mask, with full attention between the vision tokens
         # pytorch uses additive attention mask; fill with -inf
-        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
-        mask.fill_(torch.tensor(torch.finfo(dtype).min))
+        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
+        mask.fill_(torch.finfo(dtype).min)
         mask.triu_(1)  # zero out the lower diagonal
         mask = mask.unsqueeze(1)  # expand mask
         return mask
...
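
A secondary cleanup repeated in each hunk: mask.fill_(torch.tensor(torch.finfo(dtype).min)) becomes mask.fill_(torch.finfo(dtype).min). Tensor.fill_ accepts a plain Python number, so wrapping the value in torch.tensor(...) only built a throwaway CPU scalar tensor on every call. A small sketch checking the two forms agree (illustrative shapes and dtype):

import torch

dtype = torch.float16
old = torch.empty(2, 4, 4, dtype=dtype)
new = torch.empty(2, 4, 4, dtype=dtype)

old.fill_(torch.tensor(torch.finfo(dtype).min))  # old form: creates an extra scalar tensor first
new.fill_(torch.finfo(dtype).min)                # new form: fills from the Python float directly

assert torch.equal(old, new)  # same values, one fewer allocation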