Unverified Commit 86d9baed authored by Chen Shengzhi's avatar Chen Shengzhi Committed by GitHub
Browse files

[Fix] Fix errors when using devices other than CUDA. (#4455)

parent 21d485f8
...@@ -227,7 +227,8 @@ class MHATokenToKVPool(KVCache): ...@@ -227,7 +227,8 @@ class MHATokenToKVPool(KVCache):
self.layer_transfer_counter = None self.layer_transfer_counter = None
self.capture_mode = False self.capture_mode = False
self.alt_stream = torch.cuda.Stream() self.device_module = torch.get_device_module(self.device)
self.alt_stream = self.device_module.Stream()
k_size, v_size = self.get_kv_size_bytes() k_size, v_size = self.get_kv_size_bytes()
logger.info( logger.info(
...@@ -339,11 +340,12 @@ class MHATokenToKVPool(KVCache): ...@@ -339,11 +340,12 @@ class MHATokenToKVPool(KVCache):
cache_v = cache_v.view(self.store_dtype) cache_v = cache_v.view(self.store_dtype)
if self.capture_mode and cache_k.shape[0] < 4: if self.capture_mode and cache_k.shape[0] < 4:
self.alt_stream.wait_stream(torch.cuda.current_stream()) current_stream = self.device_module.current_stream()
with torch.cuda.stream(self.alt_stream): self.alt_stream.wait_stream(current_stream)
with self.device_module.stream(self.alt_stream):
self.k_buffer[layer_id][loc] = cache_k self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v self.v_buffer[layer_id][loc] = cache_v
torch.cuda.current_stream().wait_stream(self.alt_stream) current_stream.wait_stream(self.alt_stream)
else: else:
self.k_buffer[layer_id][loc] = cache_k self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v self.v_buffer[layer_id][loc] = cache_v
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment