Unverified Commit d69a0633 authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

fix(server): fix has_position_ids (#395)

Fix #389
parent db2ebe39
...@@ -496,11 +496,6 @@ class CausalLM(Model): ...@@ -496,11 +496,6 @@ class CausalLM(Model):
else: else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) tokenizer.add_special_tokens({"pad_token": "[PAD]"})
self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
......
import inspect
import torch import torch
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -29,6 +30,12 @@ class Model(ABC): ...@@ -29,6 +30,12 @@ class Model(ABC):
self.device = device self.device = device
self.rank = rank self.rank = rank
self.world_size = world_size self.world_size = world_size
self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
)
self.check_initialized() self.check_initialized()
@property @property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment