Unverified Commit e563983d authored by Wang, Yi, committed by GitHub

fix cpu and xpu issue (#2116)


Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 9e2fdf57
@@ -768,6 +768,9 @@ class FlashCausalLM(Model):
         empty_cache()
 
         element_size = torch.tensor([], dtype=dtype).element_size()
-        x = BLOCK_SIZE // element_size
+        if SYSTEM == "ipex" and device.type == "xpu":
+            x = 1
+        else:
+            x = BLOCK_SIZE // element_size
 
         if SYSTEM == "ipex" and device == torch.device("cpu"):
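A note on the hunk above: `x` is the packing factor of the paged-attention key-cache layout, which in this codebase is allocated roughly as [num_blocks, num_heads, head_size // x, BLOCK_SIZE, x]; the ipex XPU kernel expects the unpacked layout, hence x = 1, while other backends pack BLOCK_SIZE // element_size elements into the contiguous last dimension. A minimal sketch of that relationship, with illustrative sizes rather than the repository's exact allocation code:

    import torch

    BLOCK_SIZE = 16  # tokens per KV-cache block
    dtype = torch.float16
    element_size = torch.tensor([], dtype=dtype).element_size()  # 2 bytes for float16

    on_xpu = False  # assumption: True when SYSTEM == "ipex" and device.type == "xpu"
    x = 1 if on_xpu else BLOCK_SIZE // element_size  # 8 for float16; 1 on xpu

    num_blocks, num_heads, head_size = 8, 4, 64  # illustrative values only
    # Packing splits head_size into head_size // x groups of x contiguous elements;
    # it reshapes the key cache for vectorized loads but never changes its size.
    key_cache = torch.empty(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x, dtype=dtype)
    value_cache = torch.empty(num_blocks, num_heads, head_size, BLOCK_SIZE, dtype=dtype)
    assert key_cache.numel() == value_cache.numel()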
@@ -37,9 +37,10 @@ class FlashGPT2(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGPT2 is only available on GPU")
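The remaining hunks apply one identical fix to every Flash* model in this commit: the float16 fallback used to sit after the xpu/cpu branch, so CPU runs also defaulted to float16, which most CPUs only emulate; the fix moves the fallback into each branch, keeping float16 on XPU and switching CPU to bfloat16. A minimal sketch of the resulting selection logic (the cuda branch is assumed from context; SYSTEM and rank come from the surrounding model constructors):

    import torch

    def pick_device_and_dtype(SYSTEM, rank, dtype=None):
        # Hypothetical helper condensing the pattern shared by the Flash* models.
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = torch.float16 if dtype is None else dtype
            else:
                # CPU fallback is now bfloat16: float16 is emulated (slow) on
                # most CPUs, while recent Xeons support bfloat16 natively.
                device = torch.device("cpu")
                dtype = torch.bfloat16 if dtype is None else dtype
        else:
            raise NotImplementedError("only available on GPU")
        return device, dtype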
@@ -37,9 +37,10 @@ class FlashLlama(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
@@ -41,9 +41,10 @@ class BaseFlashMistral(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")
@@ -36,9 +36,10 @@ class FlashNeoXSharded(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
@@ -37,9 +37,10 @@ class FlashRWSharded(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")
@@ -40,9 +40,10 @@ class FlashSantacoderSharded(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
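Behavior of the hypothetical helper sketched earlier on an ipex install, for reference; note that only the fallback changed, so an explicitly requested dtype is still honored on every device:

    device, dtype = pick_device_and_dtype("ipex", rank=0)
    # CPU-only host -> (device(type='cpu'), torch.bfloat16)
    # host with XPU -> (device(type='xpu', index=0), torch.float16)

    device, dtype = pick_device_and_dtype("ipex", rank=0, dtype=torch.float16)
    # explicit dtype wins on any device -> torch.float16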