Unverified Commit adc62fcd authored by topduke's avatar topduke Committed by GitHub
Browse files

Merge branch 'dygraph' into dygraph

parents 8227ad1b a81b88a0
...@@ -26,11 +26,11 @@ from .rec_metric import RecMetric ...@@ -26,11 +26,11 @@ from .rec_metric import RecMetric
from .cls_metric import ClsMetric from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric from .distillation_metric import DistillationMetric
from .table_metric import TableMetric
def build_metric(config): def build_metric(config):
support_dict = [ support_dict = [
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric" "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric"
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
...@@ -55,6 +55,7 @@ class DetMetric(object): ...@@ -55,6 +55,7 @@ class DetMetric(object):
result = self.evaluator.evaluate_image(gt_info_list, det_info_list) result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
self.results.append(result) self.results.append(result)
def get_metric(self): def get_metric(self):
""" """
return metrics { return metrics {
......
...@@ -24,8 +24,8 @@ from .cls_metric import ClsMetric ...@@ -24,8 +24,8 @@ from .cls_metric import ClsMetric
class DistillationMetric(object): class DistillationMetric(object):
def __init__(self, def __init__(self,
key=None, key=None,
base_metric_name="RecMetric", base_metric_name=None,
main_indicator='acc', main_indicator=None,
**kwargs): **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.key = key self.key = key
...@@ -42,16 +42,13 @@ class DistillationMetric(object): ...@@ -42,16 +42,13 @@ class DistillationMetric(object):
main_indicator=self.main_indicator, **self.kwargs) main_indicator=self.main_indicator, **self.kwargs)
self.metrics[key].reset() self.metrics[key].reset()
def __call__(self, preds, *args, **kwargs): def __call__(self, preds, batch, **kwargs):
assert isinstance(preds, dict) assert isinstance(preds, dict)
if self.metrics is None: if self.metrics is None:
self._init_metrcis(preds) self._init_metrcis(preds)
output = dict() output = dict()
for key in preds: for key in preds:
metric = self.metrics[key].__call__(preds[key], *args, **kwargs) self.metrics[key].__call__(preds[key], batch, **kwargs)
for sub_key in metric:
output["{}_{}".format(key, sub_key)] = metric[sub_key]
return output
def get_metric(self): def get_metric(self):
""" """
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -43,7 +43,7 @@ class ClsHead(nn.Layer): ...@@ -43,7 +43,7 @@ class ClsHead(nn.Layer):
initializer=nn.initializer.Uniform(-stdv, stdv)), initializer=nn.initializer.Uniform(-stdv, stdv)),
bias_attr=ParamAttr(name="fc_0.b_0"), ) bias_attr=ParamAttr(name="fc_0.b_0"), )
def forward(self, x): def forward(self, x, targets=None):
x = self.pool(x) x = self.pool(x)
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]]) x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
x = self.fc(x) x = self.fc(x)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment