Unverified Commit 63700628 authored by Alvaro Moran, committed by GitHub

feat(cache): StaticCache uses index_copy_ to avoid useless copy (#31857)

* feat(cache): StaticCache uses index_copy_ to avoid useless copy

Using index_copy_ allows for an explicit in-place change of the tensor.
Some backends (e.g. XLA) will otherwise copy the tensor, making the code
slower and more memory-hungry.

The proposed implementation ends up using less memory and, on XLA, results
in less recompilation. The change is also quite generic, making no
difference whatsoever on the CUDA or CPU backends (a short sketch of the
pattern follows this message).

* feat(cache): SlidingWindowCache uses index_copy_ to avoid useless copy

Applies the same change made in StaticCache.

* fix(cache): fallback of index_copy_ when not implemented

* fix(cache): in index_copy_ ensure tensors are on same device

* [run slow] llama

* fix(cache): add move of cache_position to same device in SlidingWindowCache

* Revert "[run slow] llama"

This reverts commit 02608dd14253ccd464e31c108e0cd94364f0e8b9.
parent a009fbda
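The pattern adopted by this commit is easiest to see outside the diff. Below is a minimal, self-contained sketch of what the change does; the shapes, values, and variable names are illustrative only and are not taken from the transformers source:

```python
import torch

# Pre-allocated cache buffer, as in StaticCache: (batch, num_heads, max_cache_len, head_dim)
k_out = torch.zeros(1, 8, 16, 64)
# Two new key states to write at sequence positions 4 and 5
key_states = torch.randn(1, 8, 2, 64)
cache_position = torch.tensor([4, 5])

try:
    # Make sure the index lives on the same device as the cache buffer
    cache_position = cache_position.to(device=k_out.device)
    # In-place scatter along dim 2; equivalent to `k_out[:, :, cache_position] = key_states`,
    # but explicit about mutating k_out, so backends such as XLA need not materialize a copy
    k_out.index_copy_(2, cache_position, key_states)
except NotImplementedError:
    # Fallback for backends where index_copy_ is unsupported (e.g. MPS at the time of this PR)
    k_out[:, :, cache_position] = key_states

assert torch.equal(k_out[:, :, 4:6], key_states)
```

On CUDA and CPU both branches produce the same result; the benefit shows up on XLA, where the explicit in-place op avoids an extra tensor copy and reduces recompilation.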
@@ -862,8 +862,18 @@ class StaticCache(Cache):
             k_out.copy_(key_states)
             v_out.copy_(value_states)
         else:
-            k_out[:, :, cache_position] = key_states
-            v_out[:, :, cache_position] = value_states
+            # Note: here we use `tensor.index_copy_(dim, index, tensor)`, which is equivalent to
+            # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and explicitly
+            # performs an in-place operation, which avoids copies and uses less memory.
+            try:
+                # If using several devices (e.g. multiple GPUs), we need to ensure everything is on the same one
+                cache_position = cache_position.to(device=k_out.device)
+                k_out.index_copy_(2, cache_position, key_states)
+                v_out.index_copy_(2, cache_position, value_states)
+            except NotImplementedError:
+                # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
+                k_out[:, :, cache_position] = key_states
+                v_out[:, :, cache_position] = value_states

         return k_out, v_out
@@ -958,8 +968,14 @@ class SlidingWindowCache(StaticCache):
         k_out = k_out[:, :, indices]
         v_out = v_out[:, :, indices]

-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
+        try:
+            cache_position = cache_position.to(device=k_out.device)
+            k_out.index_copy_(2, cache_position, key_states)
+            v_out.index_copy_(2, cache_position, value_states)
+        except NotImplementedError:
+            # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
+            k_out[:, :, cache_position] = key_states
+            v_out[:, :, cache_position] = value_states

         # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
         self.key_cache[layer_idx].zero_()
...