Unverified Commit 99f4156f authored by q.yao's avatar q.yao Committed by GitHub
Browse files

Fix cuda reinitialization in a multiprocessing setting (#862)

parent 20d8f47a
...@@ -83,12 +83,18 @@ class TurbomindModelConfig: ...@@ -83,12 +83,18 @@ class TurbomindModelConfig:
return True return True
_WEIGHT_DTYPE_MAP = dict( def _weight_dtype_map(weight_type: str, default=None):
"""get weight dtype map."""
_WEIGHT_DTYPE_MAP = dict(
int4=torch.float16, int4=torch.float16,
fp16=torch.float16, fp16=torch.float16,
fp32=torch.float16, fp32=torch.float16,
bf16=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, bf16=torch.bfloat16
) if torch.cuda.is_bf16_supported() else torch.float16,
)
return _WEIGHT_DTYPE_MAP.get(weight_type, default)
class BaseOutputModel(ABC): class BaseOutputModel(ABC):
...@@ -153,7 +159,7 @@ class BaseOutputModel(ABC): ...@@ -153,7 +159,7 @@ class BaseOutputModel(ABC):
if self.to_file: if self.to_file:
if torch.is_floating_point(param): if torch.is_floating_point(param):
torch_type = _WEIGHT_DTYPE_MAP.get(self.cfg.weight_type, torch_type = _weight_dtype_map(self.cfg.weight_type,
torch.float16) torch.float16)
param = param.to(torch_type) param = param.to(torch_type)
tprint(name, param.shape) tprint(name, param.shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment