Unverified Commit 2a939f20 authored by Glen Taggart, committed by GitHub

Substantially reduce memory usage in _update_causal_mask for large batches by using .expand instead of .repeat [needs tests+sanity check] (#29413)

* try to fix gemma mem use

* fix: handle attention mask dim==2 case

* remove logits=logits.float()

* clean up + add llama

* apply formatting

* readability edit: swap order of items being multiplied

* revert change unrelated to PR

* revert black autoformat

* switch to one .to

* Accept style edits
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 965cf677
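For context on the change below, here is a minimal sketch (not part of the commit; sizes are illustrative) of the memory difference between the two calls: `.repeat` materializes one full copy of the causal mask per batch entry, while `.expand` returns a stride-0 view over a single copy.

```python
# Sketch: .repeat allocates per-batch copies; .expand is a zero-copy view.
# Tensor sizes here are illustrative, not taken from the commit.
import torch

batch_size, seq_len = 8, 1024
min_dtype = torch.finfo(torch.float16).min
base = torch.triu(torch.full((seq_len, seq_len), min_dtype, dtype=torch.float16), diagonal=1)

repeated = base[None, None, :, :].repeat(batch_size, 1, 1, 1)    # batch_size full copies
expanded = base[None, None, :, :].expand(batch_size, 1, -1, -1)  # view over one copy

print(repeated.untyped_storage().nbytes())  # batch_size * seq_len**2 * 2 bytes (16 MiB here)
print(expanded.untyped_storage().nbytes())  # seq_len**2 * 2 bytes (2 MiB here)
print(expanded.stride())                    # batch dim stride is 0: all entries alias one mask
```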
src/transformers/models/gemma/modeling_gemma.py

```diff
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch Gemma model."""
+
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
@@ -971,10 +972,11 @@ class GemmaModel(GemmaPreTrainedModel):
         # We use the current dtype to avoid any overflows
         min_dtype = torch.finfo(dtype).min
-        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
-        causal_mask = causal_mask.to(dtype=dtype, device=device)
+        causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
+        causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
         if attention_mask is not None and attention_mask.dim() == 2:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
```
...
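The `.clone()` added inside the `dim() == 2` branch is what makes the in-place padding edit safe: every batch entry of an expanded mask aliases the same storage. A small sketch (not from the commit) of the two failure modes the clone avoids:

```python
# Sketch: writing into an .expand()ed tensor either errors or leaks across
# batch entries, because they all share one underlying buffer.
import torch

base = torch.zeros(1, 1, 4, 4)
mask = base.expand(2, 1, -1, -1)  # both batch entries alias base's storage

try:
    mask.masked_fill_(torch.ones_like(mask, dtype=torch.bool), 1.0)
except RuntimeError as err:
    print(err)  # PyTorch rejects the in-place op and suggests .clone()

mask[0, 0, 0, 0] = 5.0   # element-wise writes are not rejected...
print(mask[1, 0, 0, 0])  # ...but leak: prints 5.0 for the other batch entry too

base = torch.zeros(1, 1, 4, 4)               # fresh mask for the safe path
safe = base.expand(2, 1, -1, -1).clone()     # contiguous, independent per-batch copy
safe[0, 0, 0, 0] = 7.0
print(safe[1, 0, 0, 0])                      # 0.0: entries no longer share memory
```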
src/transformers/models/llama/modeling_llama.py

```diff
@@ -17,7 +17,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch LLaMA model."""
+"""PyTorch LLaMA model."""
+
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
@@ -1083,10 +1084,10 @@ class LlamaModel(LlamaPreTrainedModel):
         # We use the current dtype to avoid any overflows
         min_dtype = torch.finfo(dtype).min
-        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
-        causal_mask = causal_mask.to(dtype=dtype, device=device)
+        causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
+        causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
         if attention_mask is not None and attention_mask.dim() == 2:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
```
...
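Taken together, the new code path can be exercised standalone. The sketch below wraps the diffed lines in a hypothetical `update_causal_mask` helper (the function name and signature are illustrative, not the models' actual method) and runs it on a toy padded batch:

```python
# Sketch: the updated mask construction as a standalone function, mirroring the
# diff above. Helper name, signature, and test sizes are illustrative.
import torch

def update_causal_mask(causal_mask_2d, attention_mask, batch_size, dtype, device):
    min_dtype = torch.finfo(dtype).min
    causal_mask = causal_mask_2d[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
    causal_mask = causal_mask.expand(batch_size, 1, -1, -1)  # zero-copy over the batch
    if attention_mask is not None and attention_mask.dim() == 2:
        causal_mask = causal_mask.clone()  # contiguous copy for the in-place edit
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
        causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
    return causal_mask

seq_len = 6
causal_2d = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)  # 1 = masked future position
attn = torch.tensor([[1, 1, 1, 1, 0, 0]])                         # one sequence, 2 padding tokens
out = update_causal_mask(causal_2d, attn, batch_size=1, dtype=torch.float32, device="cpu")
print(out[0, 0])  # min_dtype at future and padded positions, 0.0 elsewhere
```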