#!/usr/bin/awk -f
# One-time setup: make print use tabs between fields and honour an optional
# leading '--header' argument.
BEGIN {
    OFS = "\t"
    if (ARGV[1] == "--header") {
        # Blank the argument slot so awk does not treat '--header' as an
        # input file name, then write the column titles.
        ARGV[1] = ""
        print_header()
    }
}

# Print the tab-separated column titles matching the rows written by
# printall(): ten fixed columns followed by a group of three columns
# ("[epoch eval_time mAP]*") that is repeated once per evaluated epoch.
# "train_time" and "last_eval" title the training_time and last_eval_time
# values that printall() emits between epoch_avg and the repeated group.
function print_header() {
    print("file", "gpus", "batch", "total", "converge", "init", "epoch1", "epoch_avg", "train_time", "last_eval", "[epoch", "eval_time", "mAP]*")
}


# Return the quoted string stored under `key` in `line`, i.e. the text
# between the quotes of `"key": "text"`.  Yields an empty value when the key
# is not present.
#
# `pattern` and `captures` are locals (awk declares locals as extra
# parameters); callers pass only `line` and `key`.
function get_mll_string_val(line, key,    pattern, captures) {
    pattern = "\"" key "\": \"([^\"]*)\""
    # gawk's three-argument match() clears `captures` before filling it, so
    # captures[1] is empty after a failed match.
    match(line, pattern, captures)
    return captures[1]
}

# Return the run of decimal digits stored under `key` in `line`, i.e. the
# digits following `"key": `.  Yields an empty value when the key is not
# present or is not followed by digits.
#
# `pattern` and `captures` are locals (awk declares locals as extra
# parameters); callers pass only `line` and `key`.
function get_mll_int_val(line, key,    pattern, captures) {
    pattern = "\"" key "\": ([0-9]*)"
    # gawk's three-argument match() clears `captures` before filling it, so
    # captures[1] is empty after a failed match.
    match(line, pattern, captures)
    return captures[1]
}

# Return the numeric text stored under `key` in `line`: the longest run of
# digits, '.', 'e', '+' and '-' following `"key": `.  Yields an empty value
# when the key is not present.
#
# `pattern` and `captures` are locals (awk declares locals as extra
# parameters); callers pass only `line` and `key`.
function get_mll_float_val(line, key,    pattern, captures) {
    pattern = "\"" key "\": ([0-9.e+-]*)"
    # gawk's three-argument match() clears `captures` before filling it, so
    # captures[1] is empty after a failed match.
    match(line, pattern, captures)
    return captures[1]
}

# Return the "time_ms" value of an MLL line divided by 1000.
# `millis` is a local; callers pass only `line`.
function get_mll_time(line,    millis) {
    millis = get_mll_int_val(line, "time_ms")
    return millis / 1000
}

# Return the "epoch_num" value found in an MLL line (empty when absent).
# `epoch` is a local; callers pass only `line`.
function get_mll_epoch_num(line,    epoch) {
    epoch = get_mll_int_val(line, "epoch_num")
    return epoch
}

# Reset all per-file state before each input file is read, so that results
# from one log never carry over into the next.
BEGINFILE {
    stop_status = "notdone"

    # counters and timestamps that start at zero
    ranks = 0
    run_start_time = run_stop_time = 0
    training_time = last_eval_time = 0

    # values reported as -1 until the log provides them
    last_eval_epoch = -1
    init_time = avg_epoch_time = -1
    global_batch = -1

    # per-epoch tables
    delete time_at_epoch
    delete eval_time
    delete eval_acc
}

# Parallel launchers may prepend zero, one or several "0: "-style prefixes to
# a log line.  Normalise every MLL record by discarding whatever precedes the
# last ":::MLL" marker, so the rules below always see the record starting at
# the marker.
/:::MLL/ {
    sub(/^.*:::MLL/, ":::MLL", $0)
}

# Count init_start records; the count is later reported as the number of
# ranks and used to derive the per-rank batch size.
/:::MLL.*"key": "init_start"/ {
    ranks++
}

# Remember the integer "value" of the global_batch_size record; printall()
# divides it by the number of ranks.
/:::MLL.*"key": "global_batch_size"/ {
    global_batch = get_mll_int_val($0, "value")
}

# Record when each epoch starts.  The start of epoch 1 also fixes the
# initialisation time, measured from the run_start timestamp.
/:::MLL.*"key": "epoch_start"/ {
    epoch_num        = get_mll_epoch_num($0)
    epoch_start_time = get_mll_time($0)

    time_at_epoch[epoch_num] = epoch_start_time

    if (epoch_num == 1) {
        init_time = epoch_start_time - run_start_time
    }
}

# On every epoch_stop, add the epoch's duration to the running training time
# and refresh the average epoch time.
/:::MLL.*"key": "epoch_stop"/ {
    epoch_num    = get_mll_epoch_num($0)
    current_time = get_mll_time($0)

    training_time += current_time - time_at_epoch[epoch_num]

    # The average spans from the start of epoch 2 to the end of the current
    # epoch, i.e. (epoch_num - 1) epochs; epoch 1 is excluded.
    # NOTE(review): the upper bound of 39 is hard-coded; presumably later
    # epochs are deliberately left out of the average -- confirm.
    if (epoch_num > 1 && epoch_num <= 39) {
        avg_epoch_time = (current_time - time_at_epoch[2]) / (epoch_num - 1)
    }
}

# An evaluation begins: store its start timestamp both per epoch (the
# eval_stop rule turns it into a duration) and as the most recent eval start.
/:::MLL.*"key": "eval_start"/ {
    epoch_num = get_mll_epoch_num($0)
    last_eval_time = current_time = get_mll_time($0)
    eval_time[epoch_num] = current_time
}

# Store the accuracy "value" reported for an epoch, keyed by the record's
# epoch_num.
/:::MLL.*"key": "eval_accuracy"/ {
    eval_acc[get_mll_epoch_num($0)] = get_mll_float_val($0, "value")
}

# An evaluation ends: replace the start timestamp stored for this epoch by
# the elapsed evaluation time, and note the epoch as the latest one evaluated.
/:::MLL.*"key": "eval_stop"/ {
    last_eval_epoch = epoch_num = get_mll_epoch_num($0)
    eval_time[epoch_num] = get_mll_time($0) - eval_time[epoch_num]
}


# Remember the run_start timestamp; the init time and the total run time are
# both measured from it.
/:::MLL.*"key": "run_start"/ {
    run_start_time=get_mll_time($0)
}

# Print one tab-separated summary row for a log file.
#
# Parameters:
#   fname            file name printed in the first column
#   total_time       total run time, printed with two decimals
#   last_eval_epoch  "converge" column, printed as a string (ENDFILE passes
#                    stop_status here)
#   init_time        initialisation time
#   avg_epoch_time   average epoch time
#   eval_time        array of evaluation times indexed by epoch number
#   eval_acc         array of evaluation accuracies indexed by epoch number
#
# Also reads the globals ranks, global_batch, time_at_epoch, training_time
# and last_eval_time.  The names after the wide gap in the parameter list are
# locals (awk declares locals as extra parameters); callers pass only the
# first seven arguments.
#
# Row layout: fname, ranks, per-rank batch, total_time, last_eval_epoch,
# init_time, epoch-1 time, avg_epoch_time, training_time, last_eval_time,
# then one (epoch, eval_time, eval_acc) group per evaluated epoch.
function printall(fname, total_time, last_eval_epoch, init_time, avg_epoch_time, eval_time, eval_acc,    local_batch, first_epoch_time, epoch, had_order, prev_order) {
    # Per-rank batch size, or -1 when no init_start record was seen.
    local_batch = (ranks > 0) ? global_batch / ranks : -1

    # Epoch 1 is timed from its start to the start of epoch 2.  Report -1 if
    # either start is missing, and test with `in` so that merely asking does
    # not create empty array entries.
    if ((1 in time_at_epoch) && (2 in time_at_epoch)) {
        first_epoch_time = time_at_epoch[2] - time_at_epoch[1]
    } else {
        first_epoch_time = -1
    }

    printf("%s\t%d\t%d\t%.2f\t%s\t%.4f\t%.4f\t%.4f\t%.2f\t%.2f", fname, ranks, local_batch, total_time, last_eval_epoch, init_time, first_epoch_time, avg_epoch_time, training_time, last_eval_time)

    # `for (x in array)` visits elements in an unspecified order by default.
    # Ask gawk for ascending numeric index order so the per-epoch groups are
    # printed chronologically, then restore the previous traversal setting.
    had_order = ("sorted_in" in PROCINFO)
    if (had_order) {
        prev_order = PROCINFO["sorted_in"]
    }
    PROCINFO["sorted_in"] = "@ind_num_asc"

    for (epoch in eval_time) {
        printf("\t%d\t%.4f\t%.4f", epoch, eval_time[epoch], eval_acc[epoch])
    }

    if (had_order) {
        PROCINFO["sorted_in"] = prev_order
    } else {
        delete PROCINFO["sorted_in"]
    }

    printf("\n")
}

# The run has ended: capture its status and stop time and finalise the
# values that depend on them.
/:::MLL.*"key": "run_stop"/ {
    stop_status   = get_mll_string_val($0, "status")
    run_stop_time = get_mll_time($0)

    # A successful run reports the last evaluated epoch instead of the
    # literal status string.
    if (stop_status == "success") {
        stop_status = last_eval_epoch
    }

    # If an evaluation was started, convert its start timestamp into the
    # time elapsed between that start and the end of the run.
    if (last_eval_time > 0) {
        last_eval_time = run_stop_time - last_eval_time
    }
}

# After each input file has been read completely, print its summary row.
# stop_status fills the "converge" column: "notdone" when no run_stop was
# seen, the last evaluated epoch on success, otherwise the reported status.
ENDFILE {
    printall(FILENAME,
             run_stop_time - run_start_time,
             stop_status,
             init_time,
             avg_epoch_time,
             eval_time,
             eval_acc)
}
