Commit 31c962dc authored by zhe chen's avatar zhe chen
Browse files

Update huggingface model code

parent 773e90c1
......@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor):
return self.model.forward_features(tensor)
def forward(self, pixel_values):
return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel):
......@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor, labels=None):
outputs = self.model.forward(tensor)
def forward(self, pixel_values, labels=None):
outputs = self.model.forward(pixel_values)
if labels is not None:
logits = outputs['logits']
......
......@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor):
return self.model.forward_features(tensor)
def forward(self, pixel_values):
return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel):
......@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor, labels=None):
outputs = self.model.forward(tensor)
def forward(self, pixel_values, labels=None):
outputs = self.model.forward(pixel_values)
if labels is not None:
logits = outputs['logits']
......
......@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor):
return self.model.forward_features(tensor)
def forward(self, pixel_values):
return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel):
......@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor, labels=None):
outputs = self.model.forward(tensor)
def forward(self, pixel_values, labels=None):
outputs = self.model.forward(pixel_values)
if labels is not None:
logits = outputs['logits']
......
......@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor):
return self.model.forward_features(tensor)
def forward(self, pixel_values):
return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel):
......@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor, labels=None):
outputs = self.model.forward(tensor)
def forward(self, pixel_values, labels=None):
outputs = self.model.forward(pixel_values)
if labels is not None:
logits = outputs['logits']
......
......@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor):
return self.model.forward_features(tensor)
def forward(self, pixel_values):
return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel):
......@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor, labels=None):
outputs = self.model.forward(tensor)
def forward(self, pixel_values, labels=None):
outputs = self.model.forward(pixel_values)
if labels is not None:
logits = outputs['logits']
......
......@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor):
return self.model.forward_features(tensor)
def forward(self, pixel_values):
return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel):
......@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor, labels=None):
outputs = self.model.forward(tensor)
def forward(self, pixel_values, labels=None):
outputs = self.model.forward(pixel_values)
if labels is not None:
logits = outputs['logits']
......
......@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor):
return self.model.forward_features(tensor)
def forward(self, pixel_values):
return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel):
......@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor, labels=None):
outputs = self.model.forward(tensor)
def forward(self, pixel_values, labels=None):
outputs = self.model.forward(pixel_values)
if labels is not None:
logits = outputs['logits']
......
......@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor):
return self.model.forward_features(tensor)
def forward(self, pixel_values):
return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel):
......@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor, labels=None):
outputs = self.model.forward(tensor)
def forward(self, pixel_values, labels=None):
outputs = self.model.forward(pixel_values)
if labels is not None:
logits = outputs['logits']
......
......@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor):
return self.model.forward_features(tensor)
def forward(self, pixel_values):
return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel):
......@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor, labels=None):
outputs = self.model.forward(tensor)
def forward(self, pixel_values, labels=None):
outputs = self.model.forward(pixel_values)
if labels is not None:
logits = outputs['logits']
......
......@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor):
return self.model.forward_features(tensor)
def forward(self, pixel_values):
return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel):
......@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor, labels=None):
outputs = self.model.forward(tensor)
def forward(self, pixel_values, labels=None):
outputs = self.model.forward(pixel_values)
if labels is not None:
logits = outputs['logits']
......
......@@ -853,8 +853,8 @@ class InternImageModel(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor):
return self.model.forward_features(tensor)
def forward(self, pixel_values):
return self.model.forward_features(pixel_values)
class InternImageModelForImageClassification(PreTrainedModel):
......@@ -888,8 +888,8 @@ class InternImageModelForImageClassification(PreTrainedModel):
remove_center=config.remove_center, # for InternImage-H/G
)
def forward(self, tensor, labels=None):
outputs = self.model.forward(tensor)
def forward(self, pixel_values, labels=None):
outputs = self.model.forward(pixel_values)
if labels is not None:
logits = outputs['logits']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment